def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    On success, `self._mesh_shape` holds the mesh shape as an int32 vector and
    `self._device_coordinates` holds the device coordinates as an int32 array
    reshaped to `(num_tasks, num_tpu_devices_per_task, len(mesh_shape))`.

    Args:
      serialized: A serialized `TopologyProto`.

    Raises:
      ValueError: If `mesh_shape` does not have exactly 4 entries or has an
        entry < 1; if `num_tasks` or `num_tpu_devices_per_task` is negative;
        if the number of `device_coordinates` does not equal
        `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`; or if any
        coordinate is negative.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    # NOTE: the attribute is assigned before it is validated, so it is set
    # even when the check below raises.
    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 4 or (self._mesh_shape < 1).any():
        raise ValueError("`mesh_shape` must be a vector of size 4 with positive "
                         "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
        raise ValueError("`num_tasks` must be >= 0; got {}".format(
            proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
        raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
            proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
        # The third placeholder is labelled `len(mesh_shape)`, so report the
        # length rather than the shape vector itself.
        raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                         "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                         "got shape {}".format(proto.num_tasks,
                                               proto.num_tpu_devices_per_task,
                                               mesh_rank,
                                               len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if (coords < 0).any():
        raise ValueError("`device_coordinates` must be >= 0")
    self._device_coordinates = coords.reshape(
        (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))
def initialize_system(self, sess):
    """Runs `tf.tpu.initialize_system()` and reports the topology and core count.

    Args:
      sess: Object whose `run` method executes the initialization op.

    Returns:
      A `(topology, num_cores)` tuple. `topology` is the value produced by
      running the op, which is parsed here as a serialized `TopologyProto`;
      `num_cores` is `num_tasks * num_tpu_devices_per_task` from that proto.
    """
    serialized_topology = sess.run(tf.tpu.initialize_system())

    parsed = topology_pb2.TopologyProto()
    parsed.ParseFromString(serialized_topology)

    tasks = parsed.num_tasks
    devices_per_task = parsed.num_tpu_devices_per_task
    return serialized_topology, tasks * devices_per_task
def testTpuTopology(self):
    """Checks resolver topology state before and after `set_tpu_topology`.

    A freshly built resolver must have `_tpu_topology` set to `None`; once a
    serialized topology proto is supplied, `tpu_hardware_feature` must be a
    `TPUHardwareFeature` instance.
    """
    tpu_resolver = resolver.TPUClusterResolver(tpu='local')
    self.assertIsNone(tpu_resolver._tpu_topology)

    # Test set with tpu topology proto.
    topology = topology_pb2.TopologyProto(mesh_shape=[1, 1, 1, 1])
    tpu_resolver.set_tpu_topology(
        serialized_tpu_topology=topology.SerializeToString())

    hardware_feature = tpu_resolver.tpu_hardware_feature
    self.assertIsInstance(hardware_feature, topology_pb2.TPUHardwareFeature)
def serialized(self):
    """Returns the serialized form of the topology.

    The serialized value is built on first use from `self._mesh_shape` and
    `self._device_coordinates`, stored in `self._serialized`, and returned
    as-is on later calls.
    """
    cached = self._serialized
    if cached is not None:
        return cached

    coords = self._device_coordinates
    proto = topology_pb2.TopologyProto()
    proto.mesh_shape[:] = list(self._mesh_shape)
    # The first two axes of the coordinate array give the task count and the
    # per-task device count.
    proto.num_tasks = coords.shape[0]
    proto.num_tpu_devices_per_task = coords.shape[1]
    proto.device_coordinates.extend(list(coords.flatten()))

    self._serialized = proto.SerializeToString()
    return self._serialized
def _init_tpu(self, num_partitions, device_order_mode):
    """Initializes the TPU system and registers a device assignment.

    Runs `tf.tpu.initialize_system()` in a dedicated graph against
    `self._tpu`, parses the returned topology, stores the derived cluster
    parameters in `self.cluster_params`, and passes the computed device
    assignment to `py_utils.SetTpuDeviceAssignment`.

    Args:
      num_partitions: Forwarded to `py_utils.ComputationShape` together with
        the topology.
      device_order_mode: Forwarded to
        `tpu_device_assignment.device_assignment`.

    Raises:
      Exception: Any error raised while creating the session or running the
        initialization op is logged at fatal level and re-raised.
    """
    tf.logging.info('Initializing TPU to get device assignment: start')

    init_graph = tf.Graph()
    with init_graph.as_default():
        init_tpu_op = tf.tpu.initialize_system()

    try:
        session = tf.Session(
            target=self._tpu,
            graph=init_graph,
            config=self._no_opt_sess_cfg())
        serialized_topology = session.run(init_tpu_op)
    except Exception as err:
        tf.logging.fatal('TPU initialization failed: %s', err)
        raise

    parsed = topology_pb2.TopologyProto()
    parsed.ParseFromString(serialized_topology)
    tf.logging.info('topology.num_tasks: %r', parsed.num_tasks)
    tf.logging.info('topology.num_tpu_devices_per_task: %r',
                    parsed.num_tpu_devices_per_task)
    tf.logging.info('topology.mesh_shape: %r', parsed.mesh_shape)

    num_hosts = parsed.num_tasks
    total_cores = parsed.num_tpu_devices_per_task * num_hosts
    self.cluster_params = self._configure_cluster_params(
        tpu_cores=total_cores, cpu_hosts=num_hosts)

    # The topology and device assignment are assumed not to change within a
    # single address space.
    computation_shape = py_utils.ComputationShape(
        num_partitions, serialized_topology)
    assignment = tpu_device_assignment.device_assignment(
        serialized_topology,
        computation_shape=computation_shape,
        num_replicas=1,
        device_order_mode=device_order_mode)
    py_utils.SetTpuDeviceAssignment(assignment)
    tf.logging.info('Initializing TPU to get device assignment: done')
def testTpuTopology(self, serialized):
    """Validates and logs the contents of a serialized `TopologyProto`.

    The parsed proto is logged via `self.log(pb_to_json(proto))`; after
    validation the device coordinates are reshaped to
    `(num_tasks, num_tpu_devices_per_task, len(mesh_shape))` and logged too.

    Args:
      serialized: A serialized `TopologyProto`.

    Raises:
      ValueError: If `num_tasks` or `num_tpu_devices_per_task` is negative;
        if the number of `device_coordinates` does not equal
        `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`; or if any
        coordinate is negative.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)
    self.log(pb_to_json(proto))

    if proto.num_tasks < 0:
        raise ValueError("`num_tasks` must be >= 0; got {}".format(
            proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
        raise ValueError(
            "`num_tpu_devices_per_task` must be >= 0; got {}".format(
                proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
        # The third placeholder is labelled `len(mesh_shape)`, so report the
        # length rather than the shape vector itself.
        raise ValueError(
            "`device_coordinates` must have shape num_tasks ({}) * "
            "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
            "got shape {}".format(proto.num_tasks,
                                  proto.num_tpu_devices_per_task,
                                  mesh_rank,
                                  len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if (coords < 0).any():
        raise ValueError("`device_coordinates` must be >= 0")

    # The size check above guarantees this reshape is well-formed, so no
    # second size check is needed afterwards.
    coords = coords.reshape(
        (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))
    self.log(coords)
def set_tpu_topology(self, serialized_tpu_topology):
    """Sets the tpu topology info stored in this resolver.

    Args:
      serialized_tpu_topology: A serialized `TopologyProto`, parsed into
        `self._tpu_topology`.
    """
    topology = topology_pb2.TopologyProto()
    # The attribute is bound to the fresh message before parsing, so it is
    # replaced even if parsing raises.
    self._tpu_topology = topology
    topology.ParseFromString(serialized_tpu_topology)