Пример #1
0
def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    On success this sets `self._mesh_shape` to an int32 vector and
    `self._device_coordinates` to an int32 array of shape
    `[num_tasks, num_tpu_devices_per_task, len(mesh_shape)]`.

    Args:
      serialized: The serialized `TopologyProto` to parse.

    Raises:
      ValueError: If `mesh_shape` is not a vector of 4 positive entries, if
        `num_tasks` or `num_tpu_devices_per_task` is negative, if the number
        of entries in `device_coordinates` does not match
        `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`, or if any
        device coordinate is negative.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    # NOTE: the attribute is assigned before validation, so it keeps the
    # rejected value when the check below raises.
    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 4 or (self._mesh_shape < 1).any():
        raise ValueError("`mesh_shape` must be a vector of size 4 with positive "
                         "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
        raise ValueError("`num_tasks` must be >= 0; got {}".format(
            proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
        raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
            proto.num_tpu_devices_per_task))

    # `device_coordinates` is a flat list holding one coordinate per mesh
    # dimension for every device of every task.
    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
        raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                         "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                         "got shape {}".format(proto.num_tasks,
                                               proto.num_tpu_devices_per_task,
                                               mesh_rank,
                                               len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if (coords < 0).any():
        raise ValueError("`device_coordinates` must be >= 0")
    self._device_coordinates = coords.reshape(
        (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))
Пример #2
0
 def initialize_system(self, sess):
     """Run `tf.tpu.initialize_system` and report the topology and core count.

     Args:
       sess: Session-like object whose `run` method executes the
         initialization op.

     Returns:
       A `(topology, num_cores)` tuple: the value returned by `sess.run`
       (parsed here as a serialized `TopologyProto`) and the product of
       `num_tasks` and `num_tpu_devices_per_task` read from it.
     """
     parsed = topology_pb2.TopologyProto()
     topology = sess.run(tf.tpu.initialize_system())
     parsed.ParseFromString(topology)
     task_count = parsed.num_tasks
     devices_per_task = parsed.num_tpu_devices_per_task
     return topology, task_count * devices_per_task
    def testTpuTopology(self):
        # A freshly built resolver starts with no TPU topology stored.
        cluster_resolver = resolver.TPUClusterResolver(tpu='local')
        self.assertIsNone(cluster_resolver._tpu_topology)

        # Test set with tpu topology proto: after a serialized topology is
        # supplied, the hardware feature must be a `TPUHardwareFeature`.
        serialized_topology = topology_pb2.TopologyProto(
            mesh_shape=[1, 1, 1, 1]).SerializeToString()
        cluster_resolver.set_tpu_topology(
            serialized_tpu_topology=serialized_topology)
        hardware_feature = cluster_resolver.tpu_hardware_feature
        self.assertIsInstance(hardware_feature,
                              topology_pb2.TPUHardwareFeature)
Пример #4
0
    def serialized(self):
        """Return the topology as a serialized `TopologyProto`.

        The serialized value is built on first use from `self._mesh_shape` and
        `self._device_coordinates`, stored in `self._serialized`, and handed
        back unchanged on every later call.
        """
        if self._serialized is not None:
            return self._serialized

        coordinates = self._device_coordinates
        message = topology_pb2.TopologyProto()
        message.mesh_shape[:] = list(self._mesh_shape)
        # The first two axes of the coordinate array give the task count and
        # the per-task device count.
        message.num_tasks = coordinates.shape[0]
        message.num_tpu_devices_per_task = coordinates.shape[1]
        message.device_coordinates.extend(list(coordinates.flatten()))
        self._serialized = message.SerializeToString()
        return self._serialized
Пример #5
0
    def _init_tpu(self, num_partitions, device_order_mode):
        """Initialize the TPU system and register a device assignment.

        Runs `tf.tpu.initialize_system()` in a dedicated graph and session
        targeting `self._tpu`, parses the result as a `TopologyProto`, stores
        the derived cluster parameters in `self.cluster_params`, and passes a
        single-replica device assignment to `py_utils.SetTpuDeviceAssignment`.

        Args:
            num_partitions: Forwarded to `py_utils.ComputationShape` together
                with the topology to pick the computation shape.
            device_order_mode: Forwarded unchanged to
                `tpu_device_assignment.device_assignment`.

        Raises:
            Exception: Whatever session creation or running the init op
                raised, re-raised after being logged at fatal level.
        """
        tf.logging.info('Initializing TPU to get device assignment: start')

        init_graph = tf.Graph()
        with init_graph.as_default():
            init_op = tf.tpu.initialize_system()
        try:
            session = tf.Session(
                target=self._tpu,
                graph=init_graph,
                config=self._no_opt_sess_cfg(),
            )
            topology = session.run(init_op)
        except Exception as err:
            tf.logging.fatal('TPU initialization failed: %s', err)
            raise

        parsed = topology_pb2.TopologyProto()
        parsed.ParseFromString(topology)
        num_hosts = parsed.num_tasks
        cores_per_host = parsed.num_tpu_devices_per_task
        tf.logging.info('topology.num_tasks: %r', num_hosts)
        tf.logging.info('topology.num_tpu_devices_per_task: %r',
                        cores_per_host)
        tf.logging.info('topology.mesh_shape: %r', parsed.mesh_shape)
        self.cluster_params = self._configure_cluster_params(
            tpu_cores=cores_per_host * num_hosts, cpu_hosts=num_hosts)

        # We assume the topology and device assignment does not change
        # for a single address space.
        assignment = tpu_device_assignment.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_partitions, topology),
            num_replicas=1,
            device_order_mode=device_order_mode,
        )
        py_utils.SetTpuDeviceAssignment(assignment)

        tf.logging.info('Initializing TPU to get device assignment: done')
Пример #6
0
 def testTpuTopology(self, serialized):
     """Parse, log and validate a serialized `TopologyProto`.

     The parsed proto is logged through `self.log(pb_to_json(proto))`, and
     the device coordinates, reshaped to
     `[num_tasks, num_tpu_devices_per_task, len(mesh_shape)]`, are logged
     through `self.log` once validation succeeds.

     Args:
       serialized: The serialized `TopologyProto` to inspect.

     Raises:
       ValueError: If `num_tasks` or `num_tpu_devices_per_task` is negative,
         if the number of entries in `device_coordinates` does not match
         `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`, or if any
         device coordinate is negative.
     """
     proto = topology_pb2.TopologyProto()
     proto.ParseFromString(serialized)
     self.log(pb_to_json(proto))

     if proto.num_tasks < 0:
         raise ValueError("`num_tasks` must be >= 0; got {}".format(
             proto.num_tasks))
     if proto.num_tpu_devices_per_task < 0:
         raise ValueError(
             "`num_tpu_devices_per_task` must be >= 0; got {}".format(
                 proto.num_tpu_devices_per_task))

     # `device_coordinates` is a flat list holding one coordinate per mesh
     # dimension for every device of every task. This check runs once; the
     # proto is not modified afterwards, so repeating it cannot fail.
     mesh_rank = len(proto.mesh_shape)
     expected_coordinates_size = (proto.num_tasks *
                                  proto.num_tpu_devices_per_task *
                                  mesh_rank)
     if len(proto.device_coordinates) != expected_coordinates_size:
         raise ValueError(
             "`device_coordinates` must have shape num_tasks ({}) * "
             "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
             "got shape {}".format(proto.num_tasks,
                                   proto.num_tpu_devices_per_task,
                                   mesh_rank,
                                   len(proto.device_coordinates)))

     coords = np.array(proto.device_coordinates, dtype=np.int32)
     if (coords < 0).any():
         raise ValueError("`device_coordinates` must be >= 0")
     coords = coords.reshape(
         (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))
     self.log(coords)
Пример #7
0
 def set_tpu_topology(self, serialized_tpu_topology):
     """Store the TPU topology parsed from its serialized form.

     Args:
       serialized_tpu_topology: Serialized `TopologyProto` to parse into
         `self._tpu_topology`.
     """
     topology = topology_pb2.TopologyProto()
     # Attach the message before parsing so the attribute already refers to
     # the new proto even if parsing raises.
     self._tpu_topology = topology
     topology.ParseFromString(serialized_tpu_topology)