Exemple #1
0
    def as_proto(self) -> layout_pb2.MeshProto:
        """Serializes this mesh into a `layout_pb2.MeshProto` message.

        Returns:
          A new `MeshProto` holding the mesh name, one `mesh_dimensions` entry
          per dimension name (sized from the matching axis of the global device
          id array), the flattened global device ids, the local device ids, the
          string form of each local device and, when set, of each global device.
        """
        proto = layout_pb2.MeshProto()
        proto.name = self._name

        # Dimension number `axis` takes its size from the same axis of the
        # global device id array.
        for axis, dim_name in enumerate(self._dim_names):
            entry = proto.mesh_dimensions.add()
            entry.name = dim_name
            entry.size = self._global_device_ids.shape[axis]

        # Global ids are stored flattened; the dimension sizes written above
        # carry the shape information.
        proto.global_device_ids.extend(np.ravel(self._global_device_ids))
        proto.local_device_ids.extend(self._local_device_ids)
        proto.local_devices.extend(
            device.to_string() for device in self._local_devices)

        # Global devices may be unset or empty, in which case the field is left
        # untouched.
        if self._global_devices:
            proto.global_devices.extend(
                device.to_string() for device in self._global_devices)

        return proto
Exemple #2
0
    def from_string(mesh_str: str) -> 'Mesh':
        """Constructs a mesh instance from its string form `mesh_str`.

        The string is made of five or six `|`-separated fields, in order: mesh
        name, comma-separated `dim_name=size` pairs, comma-separated global
        device ids, comma-separated local device ids, comma-separated local
        devices and, optionally, comma-separated global devices. Empty local
        id, local device and global device fields are skipped.

        Args:
          mesh_str: The string to parse.

        Returns:
          The result of `Mesh.from_proto` applied to a `MeshProto` populated
          from the parsed fields.

        Raises:
          ValueError: If `mesh_str` does not have five or six fields, if a
            dimension entry is not of the form `name=size`, or if a size or a
            device id is not an integer.
        """
        # Separate elements of mesh.
        mesh_parts = mesh_str.split('|')
        global_dev_str = None
        if len(mesh_parts) == 5:
            name, mesh_dim_strs, global_id_str, local_id_str, dev_str = mesh_parts
        elif len(mesh_parts) == 6:
            (name, mesh_dim_strs, global_id_str, local_id_str, dev_str,
             global_dev_str) = mesh_parts
        else:
            raise ValueError('Invalid mesh string : %s' % mesh_str)

        # Load mesh proto.
        mesh_proto = layout_pb2.MeshProto()
        mesh_proto.name = name

        for mesh_dim_str in mesh_dim_strs.split(','):
            # Each entry must be exactly `name=size`. Checking the field count
            # explicitly gives a readable error instead of a tuple-unpacking
            # failure, and a separate `dim_name` avoids shadowing the mesh name.
            dim_fields = mesh_dim_str.split('=')
            if len(dim_fields) != 2:
                raise ValueError(
                    'Invalid mesh dimension `%s` in mesh string : %s' %
                    (mesh_dim_str, mesh_str))
            dim_name, size_str = dim_fields
            dim = mesh_proto.mesh_dimensions.add()
            dim.name = dim_name
            dim.size = int(size_str)

        mesh_proto.global_device_ids.extend(
            int(global_id) for global_id in global_id_str.split(','))

        # The remaining fields may be empty; splitting an empty string would
        # yield a single bogus '' element, so they are skipped instead.
        if local_id_str:
            mesh_proto.local_device_ids.extend(
                int(local_id) for local_id in local_id_str.split(','))

        if dev_str:
            mesh_proto.local_devices.extend(dev_str.split(','))

        if global_dev_str:
            mesh_proto.global_devices.extend(global_dev_str.split(','))

        return Mesh.from_proto(mesh_proto)