def as_proto(self) -> layout_pb2.MeshProto:
    """Serializes this mesh into a freshly built `layout_pb2.MeshProto`.

    The proto carries:
      * the mesh name;
      * one `mesh_dimensions` entry per name in `self._dim_names`, whose size
        is the matching axis length of `self._global_device_ids.shape`;
      * the global device ids, flattened with `np.ravel`;
      * the local device ids;
      * the local devices, each rendered with `to_string()`;
      * the global devices (also via `to_string()`), only when
        `self._global_devices` is truthy.

    Returns:
      The populated `layout_pb2.MeshProto`.
    """
    proto = layout_pb2.MeshProto()
    proto.name = self._name

    for axis, dim_name in enumerate(self._dim_names):
        new_dim = proto.mesh_dimensions.add()
        new_dim.name = dim_name
        new_dim.size = self._global_device_ids.shape[axis]

    # np.ravel flattens the id grid in row-major (C) order by default.
    proto.global_device_ids.extend(np.ravel(self._global_device_ids))
    proto.local_device_ids.extend(self._local_device_ids)
    proto.local_devices.extend(
        device.to_string() for device in self._local_devices)

    # Global devices are optional; skip the field entirely when unset/empty.
    if self._global_devices:
        proto.global_devices.extend(
            device.to_string() for device in self._global_devices)

    return proto
def from_string(mesh_str: str) -> 'Mesh':
    """Constructs a mesh instance from its string form `mesh_str`.

    `mesh_str` must contain five or six `|`-separated fields:
    `name|dims|global_ids|local_ids|local_devices[|global_devices]`.
    `dims` is a comma-separated list of `dim_name=size` entries; the id and
    device fields are comma-separated lists. The local id, local device and
    global device fields may be empty, meaning "no entries".

    Args:
      mesh_str: The serialized mesh.

    Returns:
      The `Mesh` produced by `Mesh.from_proto` from the parsed fields.

    Raises:
      ValueError: If the field count is not 5 or 6, a dimension entry is not
        of the form `name=size`, or a dimension size / device id is not an
        integer. The message names the offending piece and the full string.
    """
    # Separate elements of mesh.
    mesh_parts = mesh_str.split('|')
    global_dev_str = None
    if len(mesh_parts) == 5:
        name, mesh_dim_strs, global_id_str, local_id_str, dev_str = mesh_parts
    elif len(mesh_parts) == 6:
        (name, mesh_dim_strs, global_id_str, local_id_str, dev_str,
         global_dev_str) = mesh_parts
    else:
        raise ValueError('Invalid mesh string : %s' % mesh_str)

    def parse_int(text, what):
        # Re-raise int() failures with the offending field and full string.
        try:
            return int(text)
        except ValueError as err:
            raise ValueError(
                f'Invalid {what} {text!r} in mesh string : {mesh_str}'
            ) from err

    # Parse and validate every field before building the proto so malformed
    # input is reported precisely. Use `dim_name` (not `name`) so the mesh
    # name parsed above is never shadowed.
    dims = []
    for dim_str in mesh_dim_strs.split(','):
        dim_fields = dim_str.split('=')
        if len(dim_fields) != 2:
            raise ValueError(
                f'Invalid mesh dimension {dim_str!r} (expected name=size) '
                f'in mesh string : {mesh_str}')
        dim_name, size_str = dim_fields
        dims.append((dim_name, parse_int(size_str, 'mesh dimension size')))

    global_ids = [
        parse_int(id_str, 'global device id')
        for id_str in global_id_str.split(',')
    ]
    # ''.split(',') yields [''], so empty optional fields are mapped to [].
    local_ids = [
        parse_int(id_str, 'local device id')
        for id_str in local_id_str.split(',')
    ] if local_id_str else []
    local_devices = dev_str.split(',') if dev_str else []
    global_devices = global_dev_str.split(',') if global_dev_str else []

    # Load mesh proto.
    mesh_proto = layout_pb2.MeshProto()
    mesh_proto.name = name
    for dim_name, size in dims:
        dim = mesh_proto.mesh_dimensions.add()
        dim.name = dim_name
        dim.size = size
    mesh_proto.global_device_ids.extend(global_ids)
    mesh_proto.local_device_ids.extend(local_ids)
    mesh_proto.local_devices.extend(local_devices)
    mesh_proto.global_devices.extend(global_devices)

    return Mesh.from_proto(mesh_proto)