Example no. 1
    def _start_check_health_thread(self):
        # Use a dummy all-reduce as a barrier to wait for all workers to be up,
        # otherwise the check health may fail immediately.

        # Use array_ops.identity to create the dummy tensor so that we have a new
        # Tensor. If we use a constant, it may be cached on a /job:localhost
        # device, which will cause some code that relies on tensor.device to error.
        #
        # TODO(b/151232436): change to an explicit barrier if we have it.
        dummy_value = array_ops.identity([])
        logging.info("Waiting for the cluster, timeout = %s",
                     self._check_health_initial_timeout or "inf")
        try:
            self._host_cross_device_ops.reduce(
                reduce_util.ReduceOp.SUM,
                dummy_value,
                dummy_value,
                options=collective_util.Options(
                    timeout_seconds=self._check_health_initial_timeout,
                    implementation=collective_util.CommunicationImplementation.
                    RING))
            if context.is_async():
                context.async_wait()
        except errors.DeadlineExceededError:
            raise RuntimeError(
                "Timeout waiting for the cluster, timeout is %d seconds" %
                self._check_health_initial_timeout)
        logging.info("Cluster is ready.")
        self._check_health_thread_should_stop = threading.Event()
        # Start the thread as daemon to avoid it blocking the program from exiting.
        # We try our best to shut down the thread, but __del__ is not guaranteed
        # to be called when the program exits.
        self._check_health_thread = threading.Thread(target=self._check_health,
                                                     daemon=True)
        self._check_health_thread.start()
Example no. 2
 def _benchmark_eager_apply(self,
                            label,
                            device_and_format,
                            defun=False,
                            execution_mode=None):
     with context.execution_mode(execution_mode):
         device, data_format = device_and_format
         model = resnet50.ResNet50(data_format)
         if defun:
             model.call = tf.function(model.call)
         batch_size = 64
         num_burn = 5
         num_iters = 30
         with tf.device(device):
             images, _ = resnet50_test_util.random_batch(
                 batch_size, data_format)
             for _ in xrange(num_burn):
                 model(images, training=False).cpu()
             if execution_mode:
                 context.async_wait()
             gc.collect()
             start = time.time()
             for _ in xrange(num_iters):
                 model(images, training=False).cpu()
             if execution_mode:
                 context.async_wait()
             self._report(label, start, num_iters, device, batch_size,
                          data_format)
Example no. 3
    def train_epoch_fn(compression_ctrl, model, epoch):
        model.reset_metrics()
        logs = None

        if model.train_function is None:
            model.train_function = model.make_train_function()
        _, iterator = next(data_handler.enumerate_epochs())

        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
            for step in data_handler.steps():
                with trace.Trace('train',
                                 epoch_num=epoch,
                                 step_num=step,
                                 batch_size=None,
                                 _r=1):
                    callbacks.on_train_batch_begin(step)
                    tmp_logs = model.train_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
                    logs = tmp_logs
                    end_step = step + data_handler.step_increment
                    callbacks.on_train_batch_end(end_step, logs)
                    if model.stop_training:
                        break

        if logs is None:
            raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)
        callbacks.on_epoch_end(epoch, epoch_logs)
Example no. 4
 def _check_health(self, device, group_key, instance_key):
     first = True
     # We need to use a large enough value so that the all-reduce forms a
     # complete RING. In the RING implementation, when the value is too small, the
     # all-reduce may degrade into broadcasts. This means that some worker
     # failure may not be detected.
     value = array_ops.ones((32, 32), dtype=dtypes.float32)
     while True:
         if self._check_health_thread_should_stop.is_set():
             return
         timeout = None
         if first:
             # For the first check health we set timeout since it may need to do
             # group resolution, which may hang if the cluster is never healthy.
             timeout = self._check_health_initial_timeout
             first = False
         try:
             # We use a dummy all-reduce as a way to check the health of a cluster.
             # For RING it should be able to detect failed workers in the cluster if
             # the values are large enough.
             #
             # We're not using CrossDeviceOps because we need to run it with
             # pre-allocated group and instance keys.
             #
             # TODO(b/151232436): Replace the reduce with a check health op once we
             # add that.
             with ops.device(device):
                 collective_ops.all_reduce(value,
                                           group_size=self._num_workers,
                                           group_key=group_key,
                                           instance_key=instance_key,
                                           merge_op="Add",
                                           final_op="Id",
                                           subdiv_offsets=[0],
                                           communication_hint="ring",
                                           timeout=timeout)
                 if context.is_async():
                     context.async_wait()
         except (errors.UnavailableError, errors.DeadlineExceededError,
                 errors.FailedPreconditionError,
                 errors.CancelledError) as e:
             # TODO(b/151232436): Always raise UnavailableError when a peer fails.
             # Now there could be many kinds of errors:
             # - Unavailable: when the peer is not reachable, e.g. it's down.
             # - FailedPrecondition: when the peer has restarted.
             # - DeadlineExceeded: when the first check health exceeds the deadline,
             #   e.g. the peers take too long to be ready.
             # - Cancelled: when failures in ongoing collectives abort first,
             #   outgoing RPCs may be aborted with Cancelled.
             logging.error(
                 "Cluster check alive failed, aborting collectives")
             context.context().abort_collective_ops(
                 errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
             return
         except Exception as e:  # pylint: disable=broad-except
             logging.exception("Unexpected exception in check alive.")
             context.context().abort_collective_ops(
                 errors.INTERNAL,
                 "unexecpted exception in check alive: %s" % e)
             return
         time.sleep(self._check_health_interval)
Example no. 5
 def _start_check_health_thread(self):
     # Use a dummy all-reduce as a barrier to wait for all workers to be up,
     # otherwise the check health may fail immediately.
     #
     # TODO(b/151232436): change to an explicit barrier if we have it.
     dummy_value = ops.convert_to_tensor([])
     logging.info("Waiting for the cluster, timeout = %s",
                  self._check_health_initial_timeout or "inf")
     try:
         self._host_cross_device_ops.reduce(
             reduce_util.ReduceOp.SUM,
             dummy_value,
             dummy_value,
             experimental_hints=collective_util.Hints(
                 timeout_seconds=self._check_health_initial_timeout))
         if context.is_async():
             context.async_wait()
     except errors.DeadlineExceededError:
         raise RuntimeError(
             "Timeout waiting for the cluster, timeout is %d seconds" %
             self._check_health_initial_timeout)
     self._check_health_thread_should_stop = threading.Event()
     # Start the thread as daemon to avoid it blocking the program from exiting.
     # We try our best to shut down the thread, but __del__ is not guaranteed
     # to be called when the program exits.
     self._check_health_thread = threading.Thread(target=self._check_health,
                                                  daemon=True)
     self._check_health_thread.start()
Example no. 6
    def test_out_of_range_with_for_loop(self):

        with ops.device('/job:worker/task:0'):
            dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
            dataset = dataset.batch(1, drop_remainder=False)
            iterator = iter(dataset)
            v = variables.Variable(1.0)

        @def_function.function
        def train_step(iterator):
            i = next(iterator)
            v.assign_add(math_ops.reduce_mean(i))

        num_steps = 3
        for i in range(num_steps):
            try:
                with ops.device('/job:worker/task:0'):
                    train_step(iterator)
                if i == num_steps - 1:
                    context.async_wait()
            except errors.OutOfRangeError:
                context.async_clear_error()
                break

        self.assertAllEqual(v.numpy(), 4.0)
Example no. 7
    def testTwoExecutors(self):
        # Run an op on the main executor that by default uses StreamingEnqueue to
        # schedule the op to run on the remote async executor. This op produces an
        # error, i.e., division by zero, but will not be immediately caught due to
        # streaming enqueue.
        with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
            a = constant_op.constant(3)
            b = constant_op.constant(0)
            math_ops.div(a, b)

        # Run another op using another executor that disables streaming enqueue,
        # which would run the op using the tf_compute thread pool in the remote
        # worker. Since the op is not run in the same remote async executor, it
        # will not carry back that error produced by the op above, even though this
        # op is executed synchronously.
        with context.executor_scope(
                executor.new_executor(enable_async=False,
                                      enable_streaming_enqueue=False)):
            with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
                c = constant_op.constant(4)
                d = constant_op.constant(2)
                self.assertEqual(math_ops.div(c, d).numpy(), 2)

        # Sync on the context to force to catch the error produced by the first op.
        with self.assertRaises(errors.InvalidArgumentError) as cm:
            context.async_wait()
        self.assertIn('division by zero', cm.exception.message)
Example no. 8
def stop():
  """Stop current profiling session and return its result.

  Returns:
    A binary string of tensorflow.tpu.Trace. User can write the string
    to file for offline analysis by tensorboard.

  Raises:
    ProfilerNotRunningError: If there is no active profiling session.
  """
  global _profiler
  global _run_num
  with _profiler_lock:
    if _profiler is None:
      raise ProfilerNotRunningError(
          'Cannot stop profiling. No profiler is running.')
    if context.default_execution_mode == context.EAGER_MODE:
      context.async_wait()
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ProfilerSerializeToString(
          _profiler,
          buffer_)
      result = pywrap_tensorflow.TF_GetBuffer(buffer_)
    pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
    _profiler = None
    _run_num += 1
  return result
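
A minimal usage sketch based on the docstring above: the returned bytes are the serialized tensorflow.tpu.Trace, which can be written to a file for offline TensorBoard analysis. It assumes a profiling session was already started through this module's corresponding start() call, and the output path is illustrative.

    # Assumes the profiler was started earlier via this module's start().
    trace_bytes = stop()
    with open('/tmp/profile/local.trace', 'wb') as f:
        f.write(trace_bytes)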
Example no. 9
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes Multi Client DTensor.

  The following environment variables control the behavior of this function.
  If the variables are unset, DTensor will be configured to run in single-client
  mode.

  - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, to identify the
      client id of the current process.
  - DTENSOR_NUM_CLIENTS: integer, the number of clients.
  - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the dtensor
      job. The job name is used by TensorFlow in the job name section of
      the DeviceSpec.
  - DTENSOR_JOBS: string, a comma-separated list. Each item in the list is
      of format `{hostname}:{port}`, and the items must be sorted in
      alphabetical order. The implication is that the RPC port numbers of the
      clients from the same host must be ordered by the client ID.
      Examples of valid DTENSOR_JOBS values:
      - 4 clients on localhost:
        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
      - 2 clients on host1, 2 clients on host2
        `host1:10000,host1:10001,host2:10000,host2:10003`

  Args:
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know each other's devices, a
      prerequisite for data transfer through cross-worker rendezvous.
  """
    global _in_multi_client_mode
    assert context.executing_eagerly()

    _in_multi_client_mode = api.job_name() != 'localhost'

    if not _in_multi_client_mode and api.num_clients() != 1:
        raise ValueError(
            'DTENSOR_NUM_CLIENTS is set and not 1, while DTENSOR_JOB_NAME is '
            'set to localhost for single client mode.')

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if _in_multi_client_mode:
        if api.jobs() is None:
            raise ValueError(
                'DTENSOR_JOBS environment variable is required when '
                'using multi-client to properly set up communications between servers'
            )
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service,
            protocol='grpc')

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
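
The docstring above lists the environment variables that control multi-client initialization. A minimal launch sketch; the client id, the two-host topology, and the port numbers are illustrative assumptions rather than values taken from the snippet:

    import os

    # Pretend this process is client 1 of a 4-client DTensor job spread over
    # two hosts; ports on each host are ordered by client id, as required.
    os.environ['DTENSOR_CLIENT_ID'] = '1'
    os.environ['DTENSOR_NUM_CLIENTS'] = '4'
    os.environ['DTENSOR_JOB_NAME'] = 'worker'
    os.environ['DTENSOR_JOBS'] = (
        'host1:10000,host1:10001,host2:10000,host2:10001')

    # With the variables set in every client process, initialize the cluster.
    dtensor_initialize_multi_client(enable_coordination_service=True)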
Example no. 10
    def tearDown(self):
        super().tearDown()
        # Make sure all async ops finish.
        context.async_wait()

        # TODO(hthu): Remove the reset once we fixed the CopyToMesh with
        # DefaultMesh placement issue.
        reset_dtensor()
Example no. 11
    def adapt(self, data, batch_size=None, steps=None, reset_state=True):
        """Fits the state of the preprocessing layer to the data being passed.

    Arguments:
        data: The data to train on. It can be passed either as a tf.data
          Dataset, or as a numpy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples).
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If x is a
            `tf.data` dataset, and 'steps' is None, the epoch will run until
            the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.
        reset_state: Optional argument specifying whether to clear the state of
          the layer at the start of the call to `adapt`, or whether to start
          from the existing state. This argument may not be relevant to all
          preprocessing layers: a subclass of PreprocessingLayer may choose to
          throw if 'reset_state' is set to False.
    """
        _disallow_inside_tf_function('adapt')
        if not version_utils.should_use_v2():
            raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
        if not self.stateful:
            return
        if not self.streaming and self._is_adapted and not reset_state:
            raise ValueError(
                '{} does not support calling `adapt` twice without '
                'resetting the state.'.format(self.__class__.__name__))
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built and reset_state:
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False)
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True
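
adapt() above is the base-class flow that concrete Keras preprocessing layers inherit. A minimal usage sketch with the Normalization layer; the layer choice and the toy data are illustrative assumptions, not taken from the snippet:

    import numpy as np
    import tensorflow as tf

    # Fit the layer's state (mean/variance) to the data, then apply it.
    layer = tf.keras.layers.experimental.preprocessing.Normalization()
    data = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)
    layer.adapt(data)           # drives the DataHandler loop shown above
    print(layer(data).numpy())  # data standardized with the adapted statistics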
Example no. 12
 def _apply(self, defun=False, execution_mode=None):
     device, data_format = resnet50_test_util.device_and_data_format()
     model = resnet50.ResNet50(data_format)
     if defun:
         model.call = tf.function(model.call)
     with tf.device(device), context.execution_mode(execution_mode):
         images, _ = resnet50_test_util.random_batch(2, data_format)
         output = model(images, training=False)
         context.async_wait()
     self.assertEqual((2, 1000), output.shape)
Example no. 13
def barrier(mesh: layout.Mesh, barrier_name: Optional[str] = None):
    """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
  """
    if barrier_name is None:
        barrier_name = '(barrier)'

    logging.info('entering barrier before op: %s', barrier_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    # Reduction on a fully sharded tensor requires all devices to participate
    # and serves as a barrier on the mesh.
    component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
    ones = api.pack([component] * mesh.num_local_devices(),
                    layout.Layout(mesh.dim_names, mesh))

    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size: {0} while mesh has '
            'actual size: {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info('finished running barrier across all clients after '
                 'op: %s', barrier_name)
Example no. 14
    def _test_train(self, execution_mode=None):
        start = time.process_time()
        model = mnist.custom_model()

        with tf.device("CPU"), context.execution_mode(execution_mode):
            optimizer = tf.keras.optimizers.SGD(0.1)
            images, labels = random_batch(1000)
            apply_gradients(model, optimizer,
                            compute_gradients(model, images, labels))
            context.async_wait()
        end = time.process_time()
        print("time: ", end - start)
Example no. 15
  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    # Add a sync point to avoid the out-of-order issue of eager async execution
    # (b/155789951).
    context.async_wait()

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
Example no. 16
 def _test_train(self, execution_mode=None):
     start = time.process_time()
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format)
     for i in range(10):
         with tf.device(device), context.execution_mode(execution_mode):
             optimizer = tf.keras.optimizers.SGD(0.1)
             images, labels = random_batch(32, data_format)
             apply_gradients(model, optimizer,
                             compute_gradients(model, images, labels))
             context.async_wait()
     end = time.process_time()
     print("time: ", end - start)
Example no. 17
    def testCopyBetweenDevicesAsync(self):
        with context.execution_mode(context.ASYNC):
            x = constant_op.constant([[1., 2.], [3., 4.]])
            x = x.cpu()
            x = x.gpu()
            x = x.gpu()
            x = x.cpu()
            context.async_wait()

        # Invalid device
        with self.assertRaises(RuntimeError):
            x.gpu(context.context().num_gpus() + 1)
            context.async_wait()
        context.async_clear_error()
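
The test above condenses to a small standalone pattern: in ASYNC mode op errors are deferred and only surface at a sync point such as context.async_wait(), and context.async_clear_error() drops the pending error so later ops can run. A minimal sketch, assuming the same internal TensorFlow modules that the snippets on this page import:

    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op, errors

    with context.execution_mode(context.ASYNC):
        x = constant_op.constant([[1., 2.], [3., 4.]])
        x = x.cpu()            # enqueued asynchronously
        context.async_wait()   # block until pending ops finish; errors raise here

    try:
        # Requesting a GPU index that does not exist fails, possibly only at sync.
        x.gpu(context.context().num_gpus() + 1)
        context.async_wait()
    except (RuntimeError, errors.InvalidArgumentError):
        context.async_clear_error()  # clear the pending error and continue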
Example no. 19
  def testAsyncWaitIsNoOp(self):
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    # Should pass without exception even with failed remote workers
    context.async_wait()

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
Example no. 20
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes Multi Client DTensor.

  The following environment variables control the behavior of this function.
  If the variables are unset, DTensor will be configured to run in single-client
  mode.

  - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, to identify the
      client id of the current process. The default value is 0.
  - DTENSOR_NUM_CLIENTS: integer, the number of clients. The default value is 1.
  - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the dtensor
      job. The default is `localhost` when the number of clients is 1, and `worker`
      when the number of clients is greater than 1.
      The job name controls the job name section of the TensorFlow DeviceSpecs,
      e.g., `job:worker` in `/job:worker/replica:0/task:0/device:TPU:0` when
      the job name is `worker`.
  - DTENSOR_JOBS: string, a comma-separated list. Each item in the list is
      of format `{hostname}:{port}`, and the items must be sorted in
      alphabetical order. The implication is that the RPC port numbers of the
      clients from the same host must be ordered by the client ID.
      Examples of valid DTENSOR_JOBS values:
      - 4 clients on localhost:
        `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
      - 2 clients on host1, 2 clients on host2
        `host1:10000,host1:10001,host2:10000,host2:10003`

  Args:
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know each other's devices, a
      prerequisite for data transfer through cross-worker rendezvous.
  """
    assert context.executing_eagerly()

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if api.num_clients() > 1:
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service)

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
Example no. 21
    def _get_iterator(self):
        worker_iterators = _create_iterators_per_worker(
            self._cloned_datasets, self._input_workers, self._options)
        cardinality = input_lib._cardinality(self._cloned_datasets[0])  # pylint: disable=protected-access
        iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                         self._strategy, cardinality,
                                         self._enable_get_next_as_optional)
        iterator._element_spec = self.element_spec  # pylint: disable=protected-access

        # When async eager is enabled, sometimes the iterator may not finish
        # initialization before being passed to a multi-device function; add a
        # sync point here to make sure all underlying iterators are initialized.
        if context.executing_eagerly():
            context.async_wait()

        return iterator
Example no. 22
    def testOperationTimeout(self):
        context._reset_context()
        context.context().operation_timeout_in_ms = 10
        workers, _ = test_util.create_local_cluster(1, 0)
        remote.connect_to_remote_host(workers[0].target)

        q = data_flow_ops.FIFOQueue(1, dtypes.int32)

        @def_function.function
        def f():
            return q.dequeue()

        with self.assertRaises(errors.DeadlineExceededError):
            with ops.device('/job:worker/replica:0/task:0'):
                f()
            # If streaming RPC is enabled, fetch remote errors before end of execution
            context.async_wait()
Example no. 23
 def test_checkpointing(self):
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     with self.device.scope():
         different_values = self.device.pack(
             [constant_op.constant(-1.),
              constant_op.constant(3.)])
         v = variables.Variable(different_values)
         checkpoint = tracking.Checkpoint(v=v)
     save_path = checkpoint.save(prefix)
     with ops.device(self.device.name):
         v.assign(constant_op.constant(0.))
     # Make sure the checkpoint is actually written before we try to read it
     context.async_wait()
     checkpoint.restore(save_path).assert_consumed()
     with ops.device(self.device.name):
         outputs = self.device.unpack(v)
     self.assertAllClose([-1., 3.], outputs)
Example no. 24
    def _benchmark_eager_train(self,
                               label,
                               make_iterator,
                               device_and_format,
                               defun=False,
                               execution_mode=None):
        with context.execution_mode(execution_mode):
            device, data_format = device_and_format
            for batch_size in self._train_batch_sizes():
                (images, labels) = resnet50_test_util.random_batch(
                    batch_size, data_format)
                model = resnet50.ResNet50(data_format)
                # TODO(b/161911585): tf_to_corert MLIR lowering pipeline should handle
                # case when momentum is not set.
                optimizer = tf.keras.optimizers.SGD(0.1, 0.1)
                apply_grads = apply_gradients
                if defun:
                    model.call = tf.function(model.call)
                    apply_grads = tf.function(apply_gradients)

                num_burn = 3
                num_iters = 10
                with tf.device(device):
                    iterator = make_iterator((images, labels))
                    for _ in xrange(num_burn):
                        (images, labels) = iterator.next()
                        apply_grads(model, optimizer,
                                    compute_gradients(model, images, labels))
                    if execution_mode:
                        context.async_wait()
                    self._force_device_sync()
                    gc.collect()

                    start = time.time()
                    for _ in xrange(num_iters):
                        (images, labels) = iterator.next()
                        apply_grads(model, optimizer,
                                    compute_gradients(model, images, labels))
                    if execution_mode:
                        context.async_wait()
                    self._force_device_sync()
                    self._report(label, start, num_iters, device, batch_size,
                                 data_format)
Example no. 25
 def _test_train(self, execution_mode=None):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format)
     tf.compat.v2.summary.experimental.set_step(
         tf.train.get_or_create_global_step())
     logdir = tempfile.mkdtemp()
     with tf.compat.v2.summary.create_file_writer(
             logdir, max_queue=0,
             name='t0').as_default(), tf.compat.v2.summary.record_if(True):
         with tf.device(device), context.execution_mode(execution_mode):
             optimizer = tf.train.GradientDescentOptimizer(0.1)
             images, labels = random_batch(2, data_format)
             apply_gradients(model, optimizer,
                             compute_gradients(model, images, labels))
             self.assertEqual(320, len(model.variables))
             context.async_wait()
     events = events_from_logdir(logdir)
     self.assertEqual(len(events), 2)
     self.assertEqual(events[1].summary.value[0].tag, 'loss')
Example no. 26
 def testExecuteBasicAsync(self):
     with context.execution_mode(context.ASYNC):
         three = constant_op.constant(3)
         five = constant_op.constant(5)
         product = execute(b'Mul',
                           num_outputs=1,
                           inputs=[three, five],
                           attrs=('T', three.dtype.as_datatype_enum))[0]
         self.assertAllEqual(15, product)
     # Error: Invalid arguments
     context.set_execution_mode(context.ASYNC)
     with self.assertRaises(errors.InvalidArgumentError):
         execute(b'MatMul',
                 num_outputs=1,
                 inputs=[three, five],
                 attrs=('transpose_a', False, 'transpose_b', False, 'T',
                        three.dtype.as_datatype_enum))
         context.async_wait()
     context.async_clear_error()
     context.context().execution_mode = context.SYNC
Example no. 27
    def testPyFunctionAsync(self):
        def simple_fn(v):
            one = constant_op.constant(1.)
            return v + one

        @def_function.function
        def test_fn(v):
            return script_ops.eager_py_func(simple_fn, [v], dtypes.float32)

        async_executor = executor.new_executor(enable_async=True)
        with context.executor_scope(async_executor):
            test_var = variables.Variable(2.)
            self.assertAllEqual(test_fn(test_var), 3.0)
        async_executor.wait()

        with context.executor_scope(async_executor):
            test_var = variables.Variable(2.)
            result = test_fn(test_var)
            context.async_wait()
            self.assertAllEqual(result, 3.0)
Example no. 28
    def _benchmark_eager_train(self,
                               label,
                               make_iterator,
                               device_and_format,
                               defun=False,
                               execution_mode=None):
        with context.execution_mode(execution_mode):
            device, data_format = device_and_format
            for batch_size in self._train_batch_sizes():
                (images, labels) = resnet50_test_util.random_batch(
                    batch_size, data_format)
                model = resnet50.ResNet50(data_format)
                optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
                apply_grads = apply_gradients
                if defun:
                    model.call = tf.function(model.call)
                    apply_grads = tf.function(apply_gradients)

                num_burn = 3
                num_iters = 10
                with tf.device(device):
                    iterator = make_iterator((images, labels))
                    for _ in xrange(num_burn):
                        (images, labels) = iterator.next()
                        apply_grads(model, optimizer,
                                    compute_gradients(model, images, labels))
                    if execution_mode:
                        context.async_wait()
                    self._force_device_sync()
                    gc.collect()

                    start = time.time()
                    for _ in xrange(num_iters):
                        (images, labels) = iterator.next()
                        apply_grads(model, optimizer,
                                    compute_gradients(model, images, labels))
                    if execution_mode:
                        context.async_wait()
                    self._force_device_sync()
                    self._report(label, start, num_iters, device, batch_size,
                                 data_format)
Example no. 29
 @contextlib.contextmanager
 def catch_stop_iteration(self):
   """Catches errors when an iterator runs out of data."""
   try:
     yield
     context.async_wait()
   except (StopIteration, errors.OutOfRangeError):
     if (self._adapter.get_size() is None and self._inferred_steps is None and
         self._current_step > 0):
       # The input passed by the user ran out of batches.
       # Now we know the cardinality of the input(dataset or generator).
       self._inferred_steps = self._current_step
     else:
       self._insufficient_data = True
       total_epochs = self._epochs - self._initial_epoch
       logging.warning(
           "Your input ran out of data; interrupting training. "
           "Make sure that your dataset or generator can generate at "
           "least `steps_per_epoch * epochs` batches (in this case, "
           "{} batches). You may need to use the repeat() function "
           "when building your dataset.".format(total_epochs *
                                                self._inferred_steps))
Example no. 30
def _global_barrier(mesh: layout_lib.Mesh, last_op_name: str):
    """Runs a global barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients.

  Currently we allocate a fully sharded tensor with mesh shape and run an
  all_reduce on it.

  Args:
    mesh: The mesh to run the global barrier on.
    last_op_name: The last op run before the global_barrier. Mainly used for
      logging purposes.
  """
    logging.info('entering global barrier before op: %s', last_op_name)

    # Make sure all ops are consumed before running the sync.
    context.async_wait()

    shape = api._dtensor_device().pack(  # pylint: disable=protected-access
        [mesh.shape()] * mesh.num_local_devices(),
        layout_lib.Layout.replicated(mesh, rank=1))
    ones = api.call_with_layout(array_ops.ones,
                                layout_lib.Layout(mesh.dim_names, mesh),
                                shape=shape,
                                dtype=dtypes.float32)
    mesh_size = math_ops.reduce_sum(ones)
    if mesh_size != mesh.size:
        raise ValueError(
            'Global barrier produced wrong mesh size: {0} while mesh has '
            'actual size: {1}'.format(mesh_size, mesh.size))

    # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
    # from users. Consider dropping this if there is a `big` performance hit.
    context.async_wait()

    logging.info(
        'finished running global barrier across all clients after '
        'op: %s', last_op_name)
Example no. 31
 def testExecuteBasicAsync(self):
   with context.execution_mode(context.ASYNC):
     three = constant_op.constant(3)
     five = constant_op.constant(5)
     product = execute(
         b'Mul',
         num_outputs=1,
         inputs=[three, five],
         attrs=('T', three.dtype.as_datatype_enum))[0]
     self.assertAllEqual(15, product)
   # Error: Invalid arguments
   context.set_execution_mode(context.ASYNC)
   with self.assertRaises(errors.InvalidArgumentError):
     execute(
         b'MatMul',
         num_outputs=1,
         inputs=[three, five],
         attrs=('transpose_a', False, 'transpose_b', False, 'T',
                three.dtype.as_datatype_enum))
     context.async_wait()
   context.async_clear_error()
   context.set_execution_mode(context.SYNC)
Example no. 32
 def mc_sample(self,
               x,
               batch_size=None,
               steps=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
     outputs = None
     with self.distribute_strategy.scope():
         data_handler = data_adapter.DataHandler(
             x=x,
             batch_size=batch_size,
             steps_per_epoch=steps,
             initial_epoch=0,
             epochs=1,
             max_queue_size=max_queue_size,
             workers=workers,
             use_multiprocessing=use_multiprocessing,
             model=self)
         predict_function = self.make_mc_sample_function()
         for _, iterator in data_handler.enumerate_epochs():
             with data_handler.catch_stop_iteration():
                 for step in data_handler.steps():
                     tmp_batch_outputs = predict_function(iterator)
                     if not data_handler.inferred_steps:
                         context.async_wait()
                     batch_outputs = tmp_batch_outputs
                     if outputs is None:
                         outputs = nest.map_structure(
                             lambda batch_output: [batch_output],
                             batch_outputs)
                     else:
                         nest.map_structure_up_to(
                             batch_outputs,
                             lambda output, batch_output: output.append(
                                 batch_output), outputs, batch_outputs)
     all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
     return tf_utils.to_numpy_or_python_type(all_outputs)