Example #1
    def _get_test_objects(self,
                          task_type,
                          task_id,
                          num_gpus=0,
                          use_strategy_object=False,
                          local_mode=False):
        collective_keys = cross_device_utils.CollectiveKeys(
            group_key_start=10 * num_gpus +
            CollectiveAllReduceTest.collective_key_base,
            instance_key_start=num_gpus * 100 +
            CollectiveAllReduceTest.collective_key_base,
            instance_key_with_id_start=num_gpus * 10000 +
            CollectiveAllReduceTest.collective_key_base)
        if local_mode:
            if num_gpus:
                devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
            else:
                devices = ["/device:CPU:0"]

            if use_strategy_object:
                strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
                )
                strategy.extended._collective_keys = collective_keys
                strategy.extended._cross_device_ops._collective_keys = collective_keys
                return strategy, devices, ""
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    1, num_gpus, collective_keys=collective_keys)
                return collective_all_reduce_ops, devices, ""
        else:
            if num_gpus:
                devices = [
                    "/job:%s/task:%d/replica:0/device:GPU:%d" %
                    (task_type, task_id, i) for i in range(num_gpus)
                ]
            else:
                devices = [
                    "/job:%s/task:%d/replica:0/device:CPU:0" %
                    (task_type, task_id)
                ]

            if use_strategy_object:
                strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
                )
                strategy.configure(cluster_spec=self._cluster_spec,
                                   task_type=task_type,
                                   task_id=task_id)
                strategy.extended._collective_keys = collective_keys
                strategy.extended._cross_device_ops._collective_keys = collective_keys
                return (strategy, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    NUM_WORKERS, num_gpus, collective_keys=collective_keys)
                return (collective_all_reduce_ops, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])
Example #2
 def _create_multi_worker_mirrored():
     tf_config = cluster_resolver.TFConfigClusterResolver()
     master = tf_config.master()
     if tf_config.rpc_layer:
         # Strip off the rpc_layer prefix.
         master = master[len("%s://" % tf_config.rpc_layer):]
     resolver = cluster_resolver.SimpleClusterResolver(
         cluster_spec=tf_config.cluster_spec(),
         task_type=tf_config.task_type,
         task_id=tf_config.task_id,
         master=master,
         environment=tf_config.environment,
         num_accelerators={"GPU": required_gpus},
         rpc_layer=tf_config.rpc_layer or "grpc",
     )
     # Always create the strategy in eager mode so that it starts the server and
     # configures the eager context. The eager context can no longer be
     # configured after initialization.
     with context.eager_mode():
         strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
             cluster_resolver=resolver)
     # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
         multi_process_runner.barrier().wait()
     except ValueError:
         # If the creator is called in the main process,
         # multi_process_runner.barrier() raises ValueError, which is safe to
         # ignore.
         pass
     return strategy
    def testKeepLogicalDevice(self):
        gpus = tf_config.list_physical_devices('GPU')
        if len(gpus) > 1:
            self.skipTest(
                'Skip the logical device test on multiple GPUs, since partial '
                'GPU virtualization is not permitted.')
        # Cannot change logical device after the context initialization.
        context._reset_context()  # pylint: disable=protected-access
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=1)
        resolver = cluster_resolver_lib.SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type='worker',
            task_id=0)

        logical_gpus = len(gpus) * 2
        for i, device in enumerate(gpus):
            n = (i +
                 1) * logical_gpus // len(gpus) - i * logical_gpus // len(gpus)
            assert n > 0  # guaranteed if count >= len(devices)
            configs = []
            for ordinal in range(n):
                config = context.LogicalDeviceConfiguration(
                    memory_limit=64, experimental_device_ordinal=ordinal)
                configs.append(config)

            tf_config.set_logical_device_configuration(device, configs)

        collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            cluster_resolver=resolver)
        # Since we create two logical GPUs out of each physical GPU, there
        # should be twice as many logical GPUs as physical GPUs.
        self.assertLen(tf_config.list_logical_devices('GPU'), logical_gpus)
        context._reset_context()  # pylint: disable=protected-access
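
For reference, a minimal public-API sketch of the same logical-device split. This is not part of the test above: the calls come from `tf.config`, it assumes TF 2.x with at least one physical GPU, and it must run before the GPUs are initialized.

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Split the first physical GPU into two small logical GPUs.
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=64),
         tf.config.LogicalDeviceConfiguration(memory_limit=64)])
print(tf.config.list_logical_devices('GPU'))
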
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        num_tpus=None):
    if num_gpus is None:
        num_gpus = context.num_gpus()
    if num_tpus is None:
        num_tpus = context.context().list_physical_devices('TPU')
    if num_tpus:
        tpu_strategy_util.initialize_tpu_system()

    if cluster_spec and task_type and task_id is not None:
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={
                'GPU': num_gpus,
                'TPU': num_tpus
            })
        target = 'grpc://' + cluster_spec[task_type][task_id]
    else:
        cluster_resolver = SimpleClusterResolver(ClusterSpec({}),
                                                 num_accelerators={
                                                     'GPU': num_gpus,
                                                     'TPU': num_tpus
                                                 })
        target = ''

    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        cluster_resolver=cluster_resolver)

    return strategy, target
def _model_setup(test_obj, file_format):
    """Set up a MNIST Keras model for testing purposes.

  This function builds a MNIST Keras model and returns relevant information
  for testing.

  Args:
    test_obj: The `TestCase` testing object.
    file_format: File format for checkpoints. 'tf' or 'h5'.

  Returns:
    A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
    the training dataset.
  """
    batch_size = 64
    steps = 2
    with collective_strategy.CollectiveAllReduceStrategy().scope():
        # TODO(b/142509827): In rare cases this errors out at C++ level with the
        # "Connect failed" error message.
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
    # Pass saving_filepath from the parent thread to ensure every worker has the
    # same filepath to save.
    saving_filepath = os.path.join(test_obj.get_temp_dir(),
                                   'checkpoint.' + file_format)
    return model, saving_filepath, train_ds, steps
Example #6
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None):
    sess_config = config_pb2.ConfigProto()
    if num_gpus is None:
        num_gpus = len(tf_config.list_logical_devices('GPU'))

    if cluster_spec and task_type and task_id is not None:
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={'GPU': num_gpus})
        target = 'grpc://' + cluster_spec[task_type][task_id]
    else:
        cluster_resolver = SimpleClusterResolver(
            ClusterSpec({}), num_accelerators={'GPU': num_gpus})
        target = ''

    strategy = mwms_lib.CollectiveAllReduceStrategy(
        cluster_resolver=cluster_resolver)
    sess_config = strategy.update_config_proto(sess_config)

    return strategy, target, sess_config
def _model_setup(test_obj, file_format):
    """Set up a MNIST Keras model for testing purposes.

  This function builds a MNIST Keras model and returns relevant information
  for testing.

  Args:
    test_obj: The `TestCase` testing object.
    file_format: File format for checkpoints. 'tf' or 'h5'.

  Returns:
    A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
    the training dataset.
  """
    batch_size = 64
    steps = 2
    with collective_strategy.CollectiveAllReduceStrategy().scope():
        # TODO(b/142509827): In rare cases this errors out at C++ level with the
        # following error message:
        # subchannel.cc:1000] Connect failed: {"created":"@1570753640.827421717",
        # "description":"Failed to connect to remote host: Connection refused",
        # "errno":111,"file":"third_party/grpc/src/core/lib/iomgr/tcp_client_posix.cc",
        # "file_line":200,"os_error":"Connection refused","syscall":"connect",
        # "target_address":"ipv6:[::1]:17271"}
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
    # Pass saving_filepath from the parent thread to ensure every worker has the
    # same filepath to save.
    saving_filepath = os.path.join(test_obj.get_temp_dir(),
                                   'checkpoint.' + file_format)
    return model, saving_filepath, train_ds, steps
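
A condensed public-API sketch of the pattern in `_model_setup` (build and compile the model inside the strategy scope). The layer sizes and synthetic data are illustrative, it assumes a recent TF 2.x, and without a TF_CONFIG environment variable the strategy behaves like a single worker.

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    # Model variables created here are managed by the strategy.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))])
    model.compile(loss='sparse_categorical_crossentropy', optimizer='rmsprop')

# Synthetic stand-in for the MNIST data used by the test helper.
x = tf.random.uniform((64, 784))
y = tf.random.uniform((64,), maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32).repeat()
model.fit(dataset, epochs=1, steps_per_epoch=2)
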
Example #8
def _distribution_strategies():
  return [
      collective_all_reduce_strategy.CollectiveAllReduceStrategy(),
      mirrored_strategy.MirroredStrategy(),
      # TODO(pulkitb): Add parameter_server
      # parameter_server_strategy.ParameterServerStrategy(),
      one_device_strategy.OneDeviceStrategy('/cpu:0'),
  ]
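
As a hedged sketch of how such a list is usually consumed: iterate over the strategies and run the test body under each scope. The helper name `_public_strategies` and the loop body are illustrative and use only public-API strategies.

import tensorflow as tf

def _public_strategies():
    # Public-API counterparts of two of the strategies listed above.
    return [
        tf.distribute.MirroredStrategy(),
        tf.distribute.OneDeviceStrategy('/cpu:0'),
    ]

for strategy in _public_strategies():
    with strategy.scope():
        v = tf.Variable(1.0)  # created as a strategy-aware variable
    print(type(strategy).__name__, strategy.num_replicas_in_sync)
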
Example #9
  def _get_test_objects(self,
                        task_type,
                        task_id,
                        num_gpus=0,
                        communication=CollectiveCommunication.AUTO,
                        use_strategy_object=False,
                        local_mode=False):
    collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
    if local_mode:
      if num_gpus:
        devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
      else:
        devices = ["/device:CPU:0"]

      if use_strategy_object:
        strategy = (mwms_lib.CollectiveAllReduceStrategy
                    ._from_local_devices(devices, communication=communication))  # pylint: disable=protected-access
        return strategy, devices, ""
      else:
        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=devices,
            group_size=len(devices),
            collective_keys=collective_keys)
        return collective_all_reduce_ops, devices, ""
    else:
      # NCCL requires physical GPUs for every replica, which we can't do with
      # the simulated multi-host setup for now.
      assert communication != CollectiveCommunication.NCCL
      if num_gpus:
        devices = [
            "/job:%s/task:%d/replica:0/device:GPU:%d" % (task_type, task_id, i)
            for i in range(num_gpus)
        ]
      else:
        devices = [
            "/job:%s/task:%d/replica:0/device:CPU:0" % (task_type, task_id)
        ]

      if use_strategy_object:
        resolver = cluster_resolver.SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                self._cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={"GPU": num_gpus})
        strategy = mwms_lib.CollectiveAllReduceStrategy(
            cluster_resolver=resolver, communication=communication)
        return (strategy, devices,
                "grpc://" + self._cluster_spec[task_type][task_id])
      else:
        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=devices,
            group_size=len(devices) * NUM_WORKERS,
            collective_keys=collective_keys)
        return (collective_all_reduce_ops, devices,
                "grpc://" + self._cluster_spec[task_type][task_id])
Example #10
  def test_dataset_creator_input_options(self):
    dataset_fn = lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 1])
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True,
        experimental_per_replica_buffer_size=2)
    x = dataset_creator.DatasetCreator(dataset_fn, input_options=input_options)
    with collective_all_reduce_strategy.CollectiveAllReduceStrategy().scope():
      data_handler = data_adapter.get_data_handler(
          x,
          steps_per_epoch=2,
          model=sequential.Sequential([core_layers.Dense(10)]))

    # Ensuring the resulting `DistributedDatasetsFromFunction` has the right
    # options.
    self.assertTrue(data_handler._dataset._options.experimental_fetch_to_device)
    self.assertEqual(
        data_handler._dataset._options.experimental_per_replica_buffer_size, 2)
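
Outside of Keras, the same `InputOptions` values can be passed when distributing a dataset directly. A minimal sketch, assuming a recent TF 2.x where both fields exist on `tf.distribute.InputOptions`; the dataset contents are illustrative.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)

dataset = tf.data.Dataset.from_tensor_slices([1, 1]).batch(1)
# The options are applied when the strategy wraps the dataset.
dist_dataset = strategy.experimental_distribute_dataset(dataset, options=options)
for batch in dist_dataset:
    print(batch)
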
    def worker_step_fn(worker_id):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()

      @def_function.function
      def run_reduce():
        with ops.device(self._local_device):
          t_in = array_ops.ones(tensor_shape) * worker_id
          return strategy.reduce(reduce_util.ReduceOp.MEAN, t_in, axis=None)

      t_out = run_reduce()
      # Element values from the workers are
      #     0, 1, ..., (NUM_WORKERS - 1)
      expected_mean = (NUM_WORKERS - 1) / 2
      expected_out = np.ones(tensor_shape) * expected_mean
      self.assertAllClose(t_out, expected_out)
    def worker_step_fn(worker_id, num_dims):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()
      tensor_shape = [2] * num_dims

      def variable_fn():
        with ops.device(self._local_device):
          # The initial value will be broadcasted from worker 0 to others.
          initial_value = (array_ops.ones(tensor_shape) if worker_id == 0 else
                           array_ops.zeros(tensor_shape))
          var = variable_scope.get_variable(name='x', initializer=initial_value)
          return array_ops.identity(var)

      t_out = strategy.extended.call_for_each_replica(variable_fn)
      expected_out = np.ones(tensor_shape)
      self.assertAllClose(t_out, expected_out)
    def worker_step_fn():
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_id = tf_config['task']['index']

      @def_function.function
      def run_reduce():
        with ops.device(local_device):
          t_in = array_ops.ones(tensor_shape) * worker_id
          return strategy.reduce(reduce_util.ReduceOp.MEAN, t_in, axis=None)

      t_out = run_reduce()
      # Element values from the workers are
      #     0, 1, ..., (num_workers - 1)
      expected_mean = (num_workers - 1) / 2
      expected_out = np.ones(tensor_shape) * expected_mean
      self.assertAllClose(t_out, expected_out)
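
The reduce pattern used by the worker functions above can be tried on a single machine. A small sketch assuming a default `MirroredStrategy`; the tensor shape is illustrative.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def step_fn():
    # Each replica contributes a tensor of ones.
    return tf.ones([2, 2])

per_replica = strategy.run(step_fn)
# With identical per-replica values the MEAN reduction is still all ones.
mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)
print(mean.numpy())
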
Example #14
  def test_complete_flow_standalone_client_collective_nccl(
      self, eval_distribute_class):
    train_distribute = (
        collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            communication=cross_device_ops_lib.CollectiveCommunication.NCCL))

    if eval_distribute_class:
      eval_distribute = self._get_strategy_object(
          eval_distribute_class, eval_strategy=True)
    else:
      eval_distribute = None

    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec.pop("ps", None)
    estimator = self._complete_flow(train_distribute, eval_distribute,
                                    cluster_spec)
    self._inspect_train_and_eval_events(estimator)
Example #15
def _model_setup():
  """Set up a MNIST Keras model for testing purposes.

  Builds a MNIST Keras model and returns model information.

  Returns:
    A tuple of (batch_size, steps, train_ds, model).
  """
  context.set_log_device_placement(True)
  batch_size = 64
  steps = 2
  with collective_strategy.CollectiveAllReduceStrategy().scope():
    # TODO(b/142509827): In rare cases this errors out at C++ level with the
    # "Connect failed" error message.
    train_ds, _ = mnist_testing_utils.mnist_synthetic_dataset(batch_size, steps)
    model = mnist_testing_utils.get_mnist_model((28, 28, 1))
  return batch_size, steps, train_ds, model
 def testKeepLogicalDevice(self):
   # Cannot change logical device after the context initialization.
   context._reset_context()  # pylint: disable=protected-access
   cluster_spec = multi_worker_test_base.create_cluster_spec(
       has_chief=False, num_workers=1)
   resolver = cluster_resolver_lib.SimpleClusterResolver(
       cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
       task_type='worker',
       task_id=0)
   gpus = tf_config.list_physical_devices('GPU')
   tf_config.set_logical_device_configuration(gpus[-1], [
       context.LogicalDeviceConfiguration(64),
       context.LogicalDeviceConfiguration(64),
   ])
   collective_all_reduce_strategy.CollectiveAllReduceStrategy(
       cluster_resolver=resolver)
   # Since we create two logical GPUs out of the last GPU, there should be one
   # more logical GPU than there are physical GPUs.
   self.assertLen(tf_config.list_logical_devices('GPU'), len(gpus) + 1)
   context._reset_context()  # pylint: disable=protected-access
Example #17
def main(_):
    if flags.FLAGS.enable_eager:
        ops.enable_eager_execution()
        logging.info('Eager execution enabled for MNIST Multi-Worker.')
    else:
        logging.info('Eager execution not enabled for MNIST Multi-Worker.')

    # Build the train and eval datasets from the MNIST data.
    train_ds, eval_ds = get_input_datasets()

    if flags.FLAGS.distribution_strategy == 'multi_worker_mirrored':
        # MultiWorkerMirroredStrategy for multi-worker distributed MNIST training.
        strategy = collective_strategy.CollectiveAllReduceStrategy()
    else:
        raise ValueError(
            'Only the `multi_worker_mirrored` strategy is supported in the '
            'Keras MNIST example at this time. The strategy passed in '
            'is %s' % flags.FLAGS.distribution_strategy)

    # Create and compile the model under Distribution strategy scope.
    # `fit`, `evaluate` and `predict` will be distributed based on the strategy
    # model was compiled with.
    with strategy.scope():
        model = get_model()
        optimizer = rmsprop.RMSProp(learning_rate=0.001)
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=optimizer,
                      metrics=['accuracy'])

    # Train the model with the train dataset.
    tensorboard_callback = keras.callbacks.TensorBoard(
        log_dir=flags.FLAGS.model_dir)
    model.fit(x=train_ds,
              epochs=20,
              steps_per_epoch=468,
              callbacks=[tensorboard_callback])

    # Evaluate the model with the eval dataset.
    score = model.evaluate(eval_ds, steps=10, verbose=0)
    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))
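
For context, the multi-worker strategy used here reads the cluster layout from the TF_CONFIG environment variable. A minimal two-worker configuration looks like the sketch below; the addresses are placeholders and 'index' should be 1 in the second worker's environment.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 0},
})
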
Example #18
    def proc_func():
      global_batch_size = per_worker_batch_size * num_workers
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()

      callbacks = [
          keras.callbacks.ModelCheckpoint(
              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
      ]

      multi_worker_dataset = mnist_dataset(global_batch_size)
      if shard_policy:
        options = dataset_ops.Options()
        options.experimental_distribute.auto_shard_policy = shard_policy
        multi_worker_dataset = multi_worker_dataset.with_options(options)

      multi_worker_model.fit(
          multi_worker_dataset,
          epochs=3,
          steps_per_epoch=70,
          callbacks=callbacks)
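
The `shard_policy` branch above uses the standard `tf.data` distribute option. A standalone sketch of forcing data-based auto-sharding; the dataset contents are illustrative.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(list(range(64))).batch(8)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)
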
 def _create_multi_worker_mirrored():
     tf_config = cluster_resolver.TFConfigClusterResolver()
     resolver = cluster_resolver.SimpleClusterResolver(
         cluster_spec=tf_config.cluster_spec(),
         task_type=tf_config.task_type,
         task_id=tf_config.task_id,
         environment=tf_config.environment,
         num_accelerators={"GPU": required_gpus},
         rpc_layer=tf_config.rpc_layer,
     )
     strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
         cluster_resolver=resolver)
     # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
         multi_process_runner.barrier().wait()
     except ValueError:
         # If the creator is called in the main process,
         # multi_process_runner.barrier() raises ValueError, which is safe to
         # ignore.
         pass
     return strategy
Example #20
 def _create_multi_worker_mirrored():
     tf_config = cluster_resolver.TFConfigClusterResolver()
     master = tf_config.master()
     if tf_config.rpc_layer:
         # Strip off the rpc_layer prefix.
         master = master[len("%s://" % tf_config.rpc_layer):]
     resolver = cluster_resolver.SimpleClusterResolver(
         cluster_spec=tf_config.cluster_spec(),
         task_type=tf_config.task_type,
         task_id=tf_config.task_id,
         master=master,
         environment=tf_config.environment,
         num_accelerators={"GPU": required_gpus},
         rpc_layer=tf_config.rpc_layer or "grpc",
     )
     # Disable the health check. We don't have a reliable way to shut down the
     # strategy (and thus the health check) at the end of a test. Turning on
     # the health check causes some flakiness since we re-create part of the
     # server when creating a strategy, and our tests are capable of handling
     # failures.
     CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
     # Always create the strategy in eager mode so that it starts the server and
     # configures the eager context. The eager context can no longer be
     # configured after initialization.
     with context.eager_mode():
         strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
             cluster_resolver=resolver)
     # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
         multi_process_runner.barrier().wait()
     except ValueError:
         # If the creator is called in the main process,
         # multi_process_runner.barrier() raises ValueError, which is safe to
         # ignore.
         pass
     return strategy
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_request_compute_metadata(*args, **kwargs):
      del kwargs  # Unused.
      if args[0] == 'instance/maintenance-event':
        if (not maintenance_event.is_set()) and (
            strategy.cluster_resolver.task_id
            == 1) and (random.randrange(0, 9) > 6):
          maintenance_event.set()

          logging.info('Maintenance notice available.')
          return 'TERMINATE_ON_HOST_MAINTENANCE'
        else:
          return 'NONE'

      return ''

    with mock.patch.object(gce_util, 'request_compute_metadata',
                           mock_request_compute_metadata), mock.patch.object(
                               gce_util, 'detect_platform',
                               lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        failure_handler = failure_handling.CoordinatedCheckpointManager(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Restored training at %d', failure_handler.total_runs)
      for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                         EPOCHS_TO_RUN):

        for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                          STEPS_PER_EPOCH):
          failure_handler.run(distributed_train_step, epoch, step)

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      training_finished.set()

      pre_del_thread_count = threading.activeCount()
      failure_handler.__del__()
      self.assertLessEqual(threading.activeCount(), pre_del_thread_count - 1)
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  maintenance_event,
                  training_finished,
                  frequent_send=False):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_request_compute_metadata(*args, **kwargs):
            del kwargs  # Unused.
            if args[0] == 'instance/maintenance-event':
                if not frequent_send:
                    time.sleep(1)
                    if (not maintenance_event.is_set()) and (random.randrange(
                            0, 20) > 18):
                        maintenance_event.set()
                        logging.info('Maintenance notice available.')
                        return 'TERMINATE_ON_HOST_MAINTENANCE'
                elif frequent_send and not maintenance_event.is_set():
                    return 'TERMINATE_ON_HOST_MAINTENANCE'

            return 'NONE'

        with mock.patch.object(
                gce_util, 'request_compute_metadata',
                mock_request_compute_metadata), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                failure_handler = failure_handling.CoordinatedCheckpointManager(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d', failure_handler.total_runs)
            for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                               EPOCHS_TO_RUN):

                for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                                  STEPS_PER_EPOCH):
                    failure_handler.run(distributed_train_step, epoch, step)

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            training_finished.set()

            running_threads = test_util.get_running_threads()
            strategy.gather(constant_op.constant([10]), axis=0)
            self.assertTrue(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertTrue(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))

            strategy.gather(constant_op.constant([10]), axis=0)

            # Explicitly call __del__ since making it None and gc.collect does
            # not invoke __del__ here.
            failure_handler.__del__()

            time.sleep(2)

            running_threads = test_util.get_running_threads()
            self.assertFalse(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertFalse(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
Example #23
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                training_started_event=None,
                raise_app_error_on_worker=None):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with mock.patch.object(gce_util, 'on_gcp', lambda: False):

      with strategy.scope():
        model = Model()
        # Named fh_ckpt because it's better for the user to keep their regular
        # checkpoint separate from the checkpoint used by
        # WorkerPreemptionHandler, since we create a CheckpointManager to
        # manage that checkpoint and only one CheckpointManager should be
        # active in a particular directory at a time.
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          if distribution_strategy_context.get_distribution_strategy(
          ).cluster_resolver.task_id == raise_app_error_on_worker:
            raise errors_impl.ResourceExhaustedError(
                node_def=None, op=None, message='Running out of resources')

          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Restored training at %d',
                   worker_preemption_watcher.total_runs)
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)
        # Add some randomness to when preemption actually happens. We should
        # trigger it for sure if the training is coming to an end and it hasn't
        # been triggered yet.
        if epoch >= EPOCHS_TO_RUN - 2:
          trigger_it = True
        else:
          trigger_it = False

        self._maybe_trigger_a_preemption(training_started_event, trigger_it)

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
    def worker_fn(
        self,
        checkpoint_dir,
        cluster_spec,
        input_arg,
        maintenance_event=None,
        training_finished=None,
        frequent_send=False,
        training_restarted=None,
        termination_config=failure_handling.TerminationConfig(grace_period=0)):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_termination_watcher_function_gce(*args, **kwargs):
            del args, kwargs
            if not frequent_send:
                time.sleep(1)
                if (not maintenance_event.is_set()) and (random.randrange(
                        0, 7) == 5):
                    maintenance_event.set()
                    logging.info('Termination notice available.')
                    return True

            elif frequent_send and not maintenance_event.is_set():
                logging.info('Termination notice available.')
                return True

            return False

        with mock.patch.object(
                gce_util, 'termination_watcher_function_gce',
                mock_termination_watcher_function_gce), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints have been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if training_restarted.is_set() and not training_finished.is_set():
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if termination_config.grace_period > 0:
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 2)
                else:
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)

            logging.info('Training finished.')
            training_finished.set()

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            running_threads = test_util.get_running_threads()
            if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                    running_threads) and test_util.has_thread(
                                        _LOCAL_WATCHER_THREAD_PREFIX,
                                        running_threads):
                try:
                    # Explicitly call __del__ since making it None and gc.collect does
                    # not invoke __del__ here.
                    preemption_handler.__del__()

                    time.sleep(2)

                    running_threads = test_util.get_running_threads()
                    self.assertFalse(
                        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                             running_threads))
                    self.assertFalse(
                        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                             running_threads))

                except urllib.error.URLError as e:
                    if 'Temporary failure in name resolution' in e.message:
                        # This is caused by a weird flakiness where mock.patch
                        # does not correctly patch
                        # gce_util.request_compute_metadata; a real request is
                        # then attempted and an error is hit inside
                        # gce_util.request_compute_metadata.
                        logging.warning('Hit a mock issue.')
                        return
    def _get_test_objects(self,
                          task_type,
                          task_id,
                          num_gpus=0,
                          communication=CollectiveCommunication.AUTO,
                          use_strategy_object=False,
                          local_mode=False):
        collective_keys = cross_device_utils.CollectiveKeys(
            group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
            op_instance_key_start=100 +
            CollectiveAllReduceTest.collective_key_base,
            variable_instance_key_start=10000 +
            CollectiveAllReduceTest.collective_key_base)
        if local_mode:
            if num_gpus:
                devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
            else:
                devices = ["/device:CPU:0"]

            if use_strategy_object:
                strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
                    communication=communication)
                strategy.extended._collective_keys = collective_keys
                strategy.extended._cross_device_ops._collective_keys = collective_keys
                return strategy, devices, ""
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    1,
                    num_gpus,
                    collective_keys=collective_keys,
                    communication=communication)
                return collective_all_reduce_ops, devices, ""
        else:
            # NCCL requires physical GPUs for every replica, which we can't do
            # with the simulated multi-host setup for now.
            assert communication != CollectiveCommunication.NCCL
            if num_gpus:
                devices = [
                    "/job:%s/task:%d/replica:0/device:GPU:%d" %
                    (task_type, task_id, i) for i in range(num_gpus)
                ]
            else:
                devices = [
                    "/job:%s/task:%d/replica:0/device:CPU:0" %
                    (task_type, task_id)
                ]

            if use_strategy_object:
                strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
                    communication=communication)
                strategy.configure(cluster_spec=self._cluster_spec,
                                   task_type=task_type,
                                   task_id=task_id)
                strategy.extended._collective_keys = collective_keys
                strategy.extended._cross_device_ops._collective_keys = collective_keys
                return (strategy, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    NUM_WORKERS,
                    num_gpus,
                    collective_keys=collective_keys,
                    communication=communication)
                return (collective_all_reduce_ops, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])
Example #26
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  input_arg='checkpoint',
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  training_restarted=None,
                  training_finished=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=variables_lib.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because it's better for the user to keep their
                # regular checkpoint separate from the checkpoint used by
                # PreemptionCheckpointHandler, since we create a
                # CheckpointManager to manage that checkpoint and only one
                # CheckpointManager should be active in a particular directory
                # at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)
                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints have been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if training_restarted and training_restarted.is_set(
            ) and not training_finished.is_set():
                logging.info('training restarted')
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if getattr(termination_config, 'grace_period', 0):
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(int(checkpoint_index[0]), 2)
                else:
                    self.assertEqual(int(checkpoint_index[0]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            training_finished.set()

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
Example #27
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__ since making it None and gc.collect does
          # not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in e.message:
            # This is caused by a weird flakiness where mock.patch does not
            # correctly patch gce_util.request_compute_metadata; a real
            # request is then attempted and an error is hit inside
            # gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
        def fn(model_path, checkpoint_dir):
            global_batch_size = per_worker_batch_size * num_workers
            strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            )
            with strategy.scope():
                multi_worker_model = build_and_compile_cnn_model()

            callbacks = [
                keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
            ]

            multi_worker_dataset = mnist_dataset(global_batch_size)
            if shard_policy:
                options = dataset_ops.Options()
                options.experimental_distribute.auto_shard_policy = shard_policy
                multi_worker_dataset = multi_worker_dataset.with_options(
                    options)

            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20,
                                   callbacks=callbacks)

            def _is_chief(task_type, task_id):
                return task_type is None or task_type == 'chief' or (
                    task_type == 'worker' and task_id == 0)

            def _get_temp_dir(dirpath, task_id):
                base_dirpath = 'workertemp_' + str(task_id)
                temp_dir = os.path.join(dirpath, base_dirpath)
                file_io.recursive_create_dir_v2(temp_dir)
                return temp_dir

            def write_filepath(filepath, task_type, task_id):
                dirpath = os.path.dirname(filepath)
                base = os.path.basename(filepath)
                if not _is_chief(task_type, task_id):
                    dirpath = _get_temp_dir(dirpath, task_id)
                return os.path.join(dirpath, base)

            task_type, task_id = (strategy.cluster_resolver.task_type,
                                  strategy.cluster_resolver.task_id)
            write_model_path = write_filepath(model_path, task_type, task_id)

            multi_worker_model.save(write_model_path)
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(
                    os.path.dirname(write_model_path))

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(model_path):
                raise RuntimeError()
            if file_io.file_exists_v2(write_model_path) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            loaded_model = keras.saving.save.load_model(model_path)
            loaded_model.fit(multi_worker_dataset,
                             epochs=2,
                             steps_per_epoch=20)

            checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
            write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
                                                  task_id)
            checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

            checkpoint_manager.save()
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(write_checkpoint_dir)

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(checkpoint_dir):
                raise RuntimeError()
            if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            latest_checkpoint = checkpoint_management.latest_checkpoint(
                checkpoint_dir)
            checkpoint.restore(latest_checkpoint)
            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20)

            logging.info('testMultiWorkerTutorial successfully ends')