Example #1
def load_op_from_signature_def(signature_def, key, import_scope=None):
    """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
      SignatureDef.

  Raises:
    NotFoundError: If the op could not be found in the graph.
  """
    tensor_info = signature_def.outputs[key]
    try:
        # The init and train ops are not strictly enforced to be operations, so
        # retrieve any graph element (can be either op or tensor).
        return utils.get_element_from_tensor_info(tensor_info,
                                                  import_scope=import_scope)
    except KeyError:
        raise errors.NotFoundError(
            None, None,
            'The {0} could not be found in the graph. Please make sure the '
            'SavedModel was created by the internal _SavedModelBuilder. If you '
            'are using the public API, please make sure the SignatureDef in the '
            'SavedModel does not contain the key "{0}".'.format(key))
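
A minimal usage sketch for the example above, assuming TF1-style graph mode, the internal import path tensorflow.python.saved_model.signature_def_utils, and a SavedModel whose signature map holds a "train" entry with a "train_op" output; the export path and both key names are hypothetical:

import tensorflow.compat.v1 as tf
from tensorflow.python.saved_model import signature_def_utils

export_dir = "/tmp/my_saved_model"  # hypothetical export location
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, ["serve"], export_dir)
    signature_def = meta_graph_def.signature_def["train"]  # hypothetical key
    try:
        # Returns the op (or tensor) recorded under "train_op" in the outputs.
        train_op = signature_def_utils.load_op_from_signature_def(
            signature_def, "train_op")
        sess.run(train_op)
    except tf.errors.NotFoundError as e:
        print("Op not found in the graph:", e.message)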
Example #2
def list_directory_v2(path):
    """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    path: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
    if not is_directory(path):
        raise errors.NotFoundError(
            node_def=None,
            op=None,
            message="Could not find directory {}".format(path))
    with errors.raise_exception_on_not_ok_status() as status:
        # Convert each element to a string, since the values in the returned
        # vector of strings should be interpreted as strings, not bytes.
        return [
            compat.as_str_any(filename)
            for filename in pywrap_tensorflow.GetChildren(
                compat.as_bytes(path), status)
        ]
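
list_directory_v2 is exported as the public tf.io.gfile.listdir endpoint, so the NotFoundError above surfaces through the public API as well; a short sketch (the directory path is hypothetical):

import tensorflow as tf

try:
    # Entry names only, in arbitrary order, without "." and "..".
    entries = tf.io.gfile.listdir("/tmp/some_dir")
    print(entries)
except tf.errors.NotFoundError as e:
    print("Directory does not exist:", e.message)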
Example #3
def _init_from_checkpoint(self, *args, **kwargs):
    """Overrides default init by loading value from checkpoint."""
    self.old_init(*args, **kwargs)
    # pylint: disable=protected-access
    if self._shared_name not in self.ckpt_var_cache:
        raise errors.NotFoundError(
            None, None, "%s not found in checkpoint" % self._shared_name)

    val = self.ckpt_var_cache[self._shared_name]
    if val is not None:
        self.assign(val)
        # Avoid assigning for the second time.
        self.ckpt_var_cache[self._shared_name] = None
Example #4
def _init_from_checkpoint(self, *args, **kwargs):
  """Overrides default init by loading value from checkpoint."""
  # pylint: disable=protected-access
  self._old_init(*args, **kwargs)
  ckpt_name = self._map_func(self._shared_name)
  if ckpt_name not in self._ckpt_var_cache:
    raise errors.NotFoundError(None, None,
                               "%s not found in checkpoint" % ckpt_name)

  val = self._ckpt_var_cache.get(ckpt_name, None)
  if val is not None:
    self.assign(val)
    # Avoid assigning for the second time.
    self._ckpt_var_cache[ckpt_name] = None
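
The two overrides above share one pattern: look up the variable's (possibly mapped) name in a checkpoint-value cache, raise errors.NotFoundError if it is absent, and assign at most once. A self-contained sketch of that pattern outside the class; every name below is illustrative, not part of the original code:

import tensorflow as tf

ckpt_var_cache = {"dense/kernel": tf.ones([2, 2])}  # illustrative cache

def restore_from_cache(var, name):
  if name not in ckpt_var_cache:
    raise tf.errors.NotFoundError(None, None,
                                  "%s not found in checkpoint" % name)
  val = ckpt_var_cache[name]
  if val is not None:
    var.assign(val)
    # Avoid assigning for the second time.
    ckpt_var_cache[name] = None

v = tf.Variable(tf.zeros([2, 2]))
restore_from_cache(v, "dense/kernel")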
Example #5
def list_directory(dirname):
    """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    dirname: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN]

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
    if not is_directory(dirname):
        raise errors.NotFoundError(None, None, 'Could not find directory')
    file_list = get_matching_files(
        os.path.join(compat.as_str_any(dirname), '*'))
    return [
        compat.as_bytes(pywrap_tensorflow.Basename(compat.as_bytes(filename)))
        for filename in file_list
    ]
Example #6
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """
    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the init
        # ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        logging.warning(
            "TPU system %s has already been initialized. "
            "Reinitializing the TPU can cause previously created "
            "variables on TPU to be lost.", tpu_name)

    logging.info("Initializing the TPU system: %s", tpu_name)

    # This function is structured the way it is for the following non-intuitive
    # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
        # Explicitly place tpu.initialize_system on the first worker to avoid
        # the error where the output node matches multiple devices.
        job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    if context.executing_eagerly():

        @function.defun
        def _tpu_init_fn():
            # In TF1, we usually close chips when compilation fails in order to
            # clear the data in infeed. In TF2, we don't need to do this because
            # infeed is no longer used, so users can recover from TPU compilation
            # failures more smoothly.
            return tpu.initialize_system(
                job=job, compilation_failure_closes_chips=False)

        # The TPU_SYSTEM device must match the device used in tpu.initialize_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        try:
            with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
                output = _tpu_init_fn()
            context.async_wait()
        except errors.InvalidArgumentError as e:
            raise errors.NotFoundError(
                None, None,
                "TPUs not found in the cluster. Failed in initialization: " +
                str(e))

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access

        serialized_topology = output.numpy()
    elif not ops.executing_eagerly_outside_functions():
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())
    else:
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            serialized_topology = tpu.initialize_system(
                job=job, compilation_failure_closes_chips=False)
            # If initialize_tpu_system is called inside tf.function, we only return
            # the serialized topology object as the tf.tpu.Topology object has to be
            # constructed in eager mode.
            return serialized_topology

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
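
In user code, initialize_tpu_system is reached through its public export, tf.tpu.experimental.initialize_tpu_system, and the NotFoundError can be caught when the resolver points at a cluster without TPUs. A sketch (the TPU name is hypothetical):

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
try:
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    print("TPU mesh shape:", topology.mesh_shape)
except tf.errors.NotFoundError as e:
    print("No TPUs found in the cluster:", e.message)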
Example #7
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initialize the TPU devices.

  Args:
    enable_coordination_service: If true, enable distributed coordination
      service to make sure that workers know about each other's devices, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

  assert context.executing_eagerly()
  in_multi_client_mode = api.job_name() != "localhost"

  # Collective GRPC servers are only necessary in multi-client setup.
  # Single clients (e.g. Forge) can use local mode of collectives.
  if in_multi_client_mode:
    if api.jobs() is None:
      raise ValueError(
          "DTENSOR_JOBS environment variable is required when using "
          "multi-client mode to properly set up communications between servers."
      )
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignment, we must
    # merge these arrays from all hosts and broadcast the result back to all
    # hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset.
    # The other 7 local devices get zero inputs. All devices on all hosts
    # participate in one AllReduce, whose result will be core IDs arranged by
    # task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any
    # ring prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None, "Initialization failed, no valid TPUs found. " + str(e))

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  "It is likely a hardware failure.\nPlease check the error "
                  "messages above to see whether that's the case.\nIf so, "
                  "consider restarting the job or trying another machine.")
    raise e

  # Optionally exchange heartbeats between workers at a fixed period.
  if in_multi_client_mode and api.heartbeat_enabled():
    logging.info("Starting DTensor heartbeat service")
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
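
The core-ID merging described in the comments of this example can be illustrated host-locally with plain NumPy: each task writes its own core IDs at a task-dependent offset and contributes zeros everywhere else, so the element-wise sum computed by the AllReduce yields core IDs arranged by task-device ordinal. A standalone sketch with illustrative numbers (two tasks, eight devices each):

import numpy as np

num_tasks = 2
num_devices_per_task = 8
num_devices = num_tasks * num_devices_per_task

# Each task only knows the IDs of its locally attached cores (illustrative values).
local_core_ids = {
    0: np.arange(100, 108, dtype=np.int32),
    1: np.arange(200, 208, dtype=np.int32),
}

# Each task places its IDs at its own offset and zeros elsewhere ...
contributions = []
for task_id in range(num_tasks):
  buf = np.zeros([num_devices], dtype=np.int32)
  start = task_id * num_devices_per_task
  buf[start:start + num_devices_per_task] = local_core_ids[task_id]
  contributions.append(buf)

# ... so summing across tasks (what the AllReduce computes) recovers the
# global core IDs arranged by task-device ordinal.
all_core_ids = np.sum(contributions, axis=0)
print(all_core_ids)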