예제 #1
0
    def serialized(self):
        """Returns the serialized form of the topology.

        The topology is encoded as a `TopologyProto` and the resulting value is
        cached on `self._serialized`, so the proto is only built on the first
        call (or whenever the cache is `None`).

        Returns:
          The result of `TopologyProto.SerializeToString()` for this topology.
        """
        if self._serialized is None:
            proto = topology_pb2.TopologyProto()
            proto.mesh_shape[:] = list(self._mesh_shape)
            # The leading two axes of `_device_coordinates` give the task count
            # and the number of devices per task.
            proto.num_tasks = self._device_coordinates.shape[0]
            proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
            # Repeated proto fields cannot be assigned to directly (protobuf
            # raises AttributeError); they must be filled in place.
            proto.device_coordinates.extend(
                list(self._device_coordinates.flatten()))
            self._serialized = proto.SerializeToString()

        return self._serialized
예제 #2
0
    def _parse_topology(self, serialized):
        """Parses a serialized `TopologyProto` into `self`.

        Populates `self._mesh_shape` (an int32 vector) and
        `self._device_coordinates` (an int32 array of shape
        `[num_tasks, num_tpu_devices_per_task, len(mesh_shape)]`) from the
        decoded proto.

        Args:
          serialized: A serialized `TopologyProto`, as accepted by
            `TopologyProto.ParseFromString`.

        Raises:
          ValueError: If `mesh_shape` is not a vector of 3 positive entries, if
            `num_tasks` or `num_tpu_devices_per_task` is negative, if the number
            of device coordinates does not equal
            `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`, or if any
            device coordinate is negative.
        """
        proto = topology_pb2.TopologyProto()
        proto.ParseFromString(serialized)

        self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
        if len(self._mesh_shape) != 3 or np.any(self._mesh_shape < 1):
            raise ValueError(
                "`mesh_shape` must be a vector of size 3 with positive "
                "entries; got {}".format(self._mesh_shape))

        if proto.num_tasks < 0:
            raise ValueError("`num_tasks` must be >= 0; got {}".format(
                proto.num_tasks))
        if proto.num_tpu_devices_per_task < 0:
            raise ValueError(
                "`num_tpu_devices_per_task` must be >= 0; got {}".format(
                    proto.num_tpu_devices_per_task))

        mesh_rank = len(proto.mesh_shape)
        # The flat coordinate list holds one `mesh_rank`-sized coordinate per
        # device, for every device of every task.
        expected_coordinates_size = (proto.num_tasks *
                                     proto.num_tpu_devices_per_task *
                                     mesh_rank)
        if len(proto.device_coordinates) != expected_coordinates_size:
            # Report the rank, not the mesh shape itself, so that the value
            # matches the `len(mesh_shape)` label in the message.
            raise ValueError(
                "`device_coordinates` must have shape num_tasks ({}) * "
                "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                "got shape {}".format(proto.num_tasks,
                                      proto.num_tpu_devices_per_task,
                                      mesh_rank,
                                      len(proto.device_coordinates)))

        coords = np.array(proto.device_coordinates, dtype=np.int32)
        if np.any(coords < 0):
            raise ValueError("`device_coordinates` must be >= 0")
        # Restore the [task, device, mesh axis] layout from the flat list.
        coords = coords.reshape(
            (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))
        self._device_coordinates = coords