Ejemplo n.º 1
0
  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should directly be calling `tf.contrib.distribute.initialize_tpu_system`
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
Ejemplo n.º 2
0
  def _create_tpu_strategy():
    resolver = tpu_cluster_resolver.TPUClusterResolver("")
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          topology, core_assignment=device_assignment_lib.
          SINGLE_CORE_ASSIGNMENT)

    strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run,
                                   device_assignment=device_assignment,
                                   **kwargs)
    return strategy
Ejemplo n.º 3
0
def get_tpu_strategy():
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)
    return strategy
 def _get_strategy(self):
     self.resolver = tpu_cluster_resolver.TPUClusterResolver(
         tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
     remote.connect_to_cluster(self.resolver)
     tpu_strategy_util.initialize_tpu_system(self.resolver)
     return tpu_strategy.TPUStrategy(self.resolver)
Ejemplo n.º 5
0
 def test_cluster_resolver_available(self, enable_packed_var):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)
     tpu_strategy_util.initialize_tpu_system(resolver)
     strategy = tpu_lib.TPUStrategy(resolver)
     self.assertIs(strategy.cluster_resolver, resolver)
Ejemplo n.º 6
0
    def _initialize_multi_worker(self, cluster_resolver):
        """Initializes the object for multi-worker training."""
        cluster_spec = multi_worker_util.normalize_cluster_spec(
            cluster_resolver.cluster_spec())
        task_type = cluster_resolver.task_type
        task_id = cluster_resolver.task_id
        if task_type is None or task_id is None:
            raise ValueError(
                "When `cluster_spec` is given, you must also specify "
                "`task_type` and `task_id`.")
        self._cluster_spec = cluster_spec
        self._task_type = task_type
        self._task_id = task_id
        self._id_in_cluster = multi_worker_util.id_in_cluster(
            self._cluster_spec, self._task_type, self._task_id)

        self._num_workers = multi_worker_util.worker_count(
            cluster_spec, task_type)
        if not self._num_workers:
            raise ValueError(
                "No `worker`, `chief` or `evaluator` tasks can be found "
                "in `cluster_spec`.")

        self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                    task_id)

        self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
        self._host_input_device = numpy_dataset.SingleDevice(
            self._worker_device)

        if (ops.executing_eagerly_outside_functions() and
                not getattr(self, "_local_or_standalone_client_mode", False)):
            context.context().configure_collective_ops(
                collective_leader=multi_worker_util.collective_leader(
                    cluster_spec, task_type, task_id),
                scoped_allocator_enabled_ops=("CollectiveReduce", ),
                device_filters=("/job:%s/task:%d" % (task_type, task_id), ))
            self._collective_ops_configured = True
            if context.context().coordination_service is None:
                coordinated_jobs = ["chief", "worker"]
                if task_type in coordinated_jobs:
                    context.context().configure_coordination_service(
                        service_type="standalone",
                        service_leader=multi_worker_util.coordination_leader(
                            cluster_spec),
                        coordinated_jobs=coordinated_jobs)

        # Starting a std server in eager mode and in independent worker mode.
        if (context.executing_eagerly()
                and not getattr(self, "_std_server_started", False) and
                not getattr(self, "_local_or_standalone_client_mode", False)):
            # Checking _local_or_standalone_client_mode as well because we should not
            # create the std server in standalone client mode.
            config_proto = copy.deepcopy(context.context().config)
            config_proto = self._update_config_proto(config_proto)

            # If coordination service is enabled, use its internal heartbeat to detect
            # peer failures instead of the Python-level health check.
            if config_proto.experimental.coordination_config.service_type:
                self._enable_check_health = False

            if hasattr(cluster_resolver, "port"):
                port = cluster_resolver.port
            else:
                port = 0
            server_def = tensorflow_server_pb2.ServerDef(
                cluster=cluster_spec.as_cluster_def(),
                default_session_config=config_proto,
                job_name=task_type,
                task_index=task_id,
                protocol=cluster_resolver.rpc_layer or "grpc",
                port=port)
            context.context().enable_collective_ops(server_def)
            self._std_server_started = True
            # The `ensure_initialized` is needed before calling
            # `context.context().devices()`.
            context.context().ensure_initialized()
            logging.info(
                "Enabled multi-worker collective ops with available devices: %r",
                context.context().devices())

        # TODO(yuefengz): The `num_gpus` is only for this particular task. It
        # assumes all workers have the same number of GPUs. We should remove this
        # assumption by querying all tasks for their numbers of GPUs.
        # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
        # some cases.
        local_devices, local_device_type = self._initialize_local_devices(
            cluster_resolver, self._worker_device)
        if local_device_type == "TPU":
            tpu_strategy_util.initialize_tpu_system()

        self._collective_keys = cross_device_utils.CollectiveKeys(
            group_key_start=1 + self._collective_key_base)
        self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=local_devices,
            group_size=len(local_devices) * self._num_workers,
            options=self._communication_options,
            collective_keys=self._collective_keys)
        # CrossDeviceOps for per host tensors.
        self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=[self._worker_device],
            group_size=self._num_workers,
            options=self._communication_options,
            collective_keys=self._collective_keys)
        super(CollectiveAllReduceExtended,
              self)._initialize_single_worker(local_devices)

        # Add a default device so that ops without specified devices will not end up
        # on other workers.
        self._default_device = "/job:%s/task:%d" % (task_type, task_id)

        # Save the num_devices_per_worker and rpc_layer for configure method.
        self._num_devices_per_worker = len(local_devices)
        self._local_device_type = local_device_type
        self._rpc_layer = cluster_resolver.rpc_layer
        self._warn_nccl_no_gpu()

        if self._enable_check_health and context.executing_eagerly():
            self._start_check_health_thread()
        else:
            logging.info("Check health not enabled.")

        logging.info(
            "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
            "task_id = %r, num_workers = %r, local_devices = %r, "
            "communication = %s", cluster_spec.as_dict(), task_type, task_id,
            self._num_workers, local_devices,
            self._communication_options.implementation)
 def test_tpu_initialization(self):
     resolver = tpu_cluster_resolver.TPUClusterResolver('')
     tpu_strategy_util.initialize_tpu_system(resolver)
  def test_checkpoint_save_retrieves(self):
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)
      first_mid_level.build(64)

    # Ensure that the variables from the first model are loaded.
    first_mid_level._load_variables()

    self.assertAllClose(
        first_mid_level_contents,
        self.make_checkpoint_and_get_embedding('before_load', first_mid_level,
                                               num_rows),
        msg='Checkpoint should contain values from the first api object.')

    # Reinitialize the tpu.
    tpu_strategy_util.initialize_tpu_system(self.resolver)

    with strategy.scope():
      second_mid_level_contents = np.ones((num_rows, 4)) * 2
      second_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(second_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      second_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, second_mid_level_optimizer)
      second_mid_level.build(64)

    second_mid_level._load_variables()

    # When we load the variables from the second mid level API object to the TPU
    # we expect that checkpointing the first mid level API object will now
    # retrieve the values from the TPU which are now different from the current
    # variables in the first mid level.
    self.assertAllClose(
        second_mid_level_contents,
        self.make_checkpoint_and_get_embedding('after_load', first_mid_level,
                                               num_rows),
        msg='Checkpoint should contain values from the second api object.')
  def test_checkpoint_restore_loads(self):
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    def get_values(mid):
      return ops.convert_to_tensor(
          mid._variables['table']['parameters'].variables[0])

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)
      first_mid_level.build(64)

    first_mid_level._load_variables()

    first_checkpoint = util.Checkpoint(model=first_mid_level)
    first_checkpoint.save(self._get_tmpdir('restore', 'save'))

    tpu_strategy_util.initialize_tpu_system(self.resolver)

    with strategy.scope():
      second_mid_level_contents = np.ones((num_rows, 4)) * 2
      second_mid_level_optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      initializer = init_ops_v2.Constant(second_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)
      second_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, second_mid_level_optimizer)
      second_mid_level.build(64)

    second_mid_level._load_variables()

    self.assertAllClose(
        second_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should contain its initial values.',
    )
    # We restore the checkpoint of our first model into our second model.
    # This should load the first mid level API object onto the TPU.
    second_checkpoint = util.Checkpoint(model=second_mid_level)
    second_checkpoint.restore(self._get_tmpdir('restore', 'save-1'))

    # Call retrieve here as a way to check what the TPU contains.
    # Calling the retrieve ops directly might make for a cleaner separation of
    # test and module, though.
    second_mid_level._retrieve_variables()

    self.assertAllClose(
        first_mid_level_contents,
        get_values(second_mid_level),
        msg='Second mid level api should have retrieved the first model values.'
    )