예제 #1
0
    async def create(cls,
                     node_id: DHTID,
                     bucket_size: int,
                     depth_modulo: int,
                     num_replicas: int,
                     wait_timeout: float,
                     parallel_rpc: Optional[int] = None,
                     cache_size: Optional[int] = None,
                     listen=True,
                     listen_on='0.0.0.0:*',
                     endpoint: Optional[Endpoint] = None,
                     channel_options: Sequence[Tuple[str, Any]] = (),
                     **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp

        :param node_id: this node's DHTID; seeds the routing table and is advertised in node_info
        :param bucket_size: passed to RoutingTable and stored as self.bucket_size
        :param depth_modulo: passed to RoutingTable
        :param num_replicas: stored as self.num_replicas
        :param wait_timeout: stored as self.wait_timeout
        :param parallel_rpc: initial value of self.rpc_semaphore; None means an unbounded semaphore
        :param cache_size: maxsize of the self.cache DHTLocalStorage
        :param listen: if True, start a grpc.aio server for incoming requests; otherwise run client-only
        :param listen_on: address given to server.add_insecure_port (only used when listen=True)
        :param endpoint: endpoint advertised in node_info; a trailing '*' is replaced by the bound port
        :param channel_options: stored (as a tuple) in self.channel_options
        :param kwargs: forwarded to grpc.aio.server (only used when listen=True)
        :returns: the fully initialized protocol instance
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
        self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        # float('inf') makes the semaphore effectively unlimited when parallel_rpc is not given
        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.aio.init_grpc_aio()
            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            self.port = self.server.add_insecure_port(listen_on)
            # a returned port of 0 is treated as a failure to bind
            assert self.port != 0, f"Failed to listen to {listen_on}"
            if endpoint is not None and endpoint.endswith('*'):
                # the advertised endpoint has a wildcard port: substitute the port we actually bound
                endpoint = replace_port(endpoint, self.port)
            self.node_info = dht_pb2.NodeInfo(
                node_id=node_id.to_bytes(),
                rpc_port=self.port,
                endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on "
                               f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self
예제 #2
0
 def endpoint(self) -> Endpoint:
     """
     Return this averager's endpoint, computing and caching it on first access.

     On the first call the endpoint is derived from ``self.listen_on`` with its port replaced by
     ``self.port``, or by the ``'*'`` wildcard when ``self.port`` is None. Later calls return the cached value.
     """
     if self._averager_endpoint is not None:
         return self._averager_endpoint
     port_or_wildcard = '*' if self.port is None else self.port
     self._averager_endpoint = replace_port(self.listen_on, port_or_wildcard)
     logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
     return self._averager_endpoint
예제 #3
0
    async def rpc_ping(self, request: dht_pb2.PingRequest,
                       context: grpc.ServicerContext):
        """
        Handle a ping from a remote node that wants to be added to our routing table.

        :param request: the incoming PingRequest; its ``peer`` describes the caller and ``validate`` asks us to ping back
        :param context: gRPC servicer context, used to read the caller's address via ``context.peer()``
        :returns: a PingResponse carrying our own node_info, the sender's endpoint as we see it, the current DHT time,
          and ``available`` (True only if validation was requested and the ping-back returned the sender's id)
        """
        response = dht_pb2.PingResponse(peer=self.node_info,
                                        sender_endpoint=context.peer(),
                                        dht_time=get_dht_time(),
                                        available=False)

        peer = request.peer
        if not (peer and peer.node_id and peer.rpc_port):
            # the caller did not describe itself well enough to be added to the routing table
            return response

        sender_id = DHTID.from_bytes(peer.node_id)
        default_endpoint = dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value
        if peer.endpoint == default_endpoint:
            # no preferred endpoint advertised: use the caller's address with its declared rpc port
            sender_endpoint = replace_port(context.peer(), new_port=peer.rpc_port)
        else:
            sender_endpoint = peer.endpoint

        response.sender_endpoint = sender_endpoint
        if request.validate:
            response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id

        responded = response.available or not request.validate
        # NOTE(review): the task reference is not kept; asyncio docs recommend retaining it so it is not GC'd early
        asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint, responded=responded))
        return response
예제 #4
0
    def __init__(self,
                 dht: Optional[DHT],
                 expert_backends: Dict[str, ExpertBackend],
                 listen_on: Endpoint = "0.0.0.0:*",
                 num_connection_handlers: int = 1,
                 update_period: int = 30,
                 start=False,
                 checkpoint_dir=None,
                 **kwargs):
        """
        Set up a server that hosts the given expert backends.

        :param dht: DHT instance stored as self.dht, or None to run without a DHT
        :param expert_backends: mapping from expert uid to its ExpertBackend; stored as self.experts
        :param listen_on: address (and optional port) shared by all connection handlers; when get_port finds no
           port in it, a free port from find_open_port() is substituted
        :param num_connection_handlers: how many ConnectionHandler objects to create
        :param update_period: stored as self.update_period and passed to CheckpointSaver
        :param start: if True, call self.run_in_background(await_ready=True) before returning
        :param checkpoint_dir: if not None, a CheckpointSaver for expert_backends is created for this directory;
           otherwise self.checkpoint_saver is None
        :param kwargs: forwarded verbatim to Runtime
        """
        super().__init__()
        self.dht = dht
        self.experts = expert_backends
        self.update_period = update_period

        resolved_listen_on = listen_on
        if get_port(resolved_listen_on) is None:
            # presumably no concrete port was requested (e.g. "host:*"), so pick a free one
            resolved_listen_on = replace_port(resolved_listen_on, new_port=find_open_port())
        self.listen_on = resolved_listen_on
        self.port = get_port(resolved_listen_on)

        handlers = []
        for _ in range(num_connection_handlers):
            handlers.append(ConnectionHandler(resolved_listen_on, self.experts))
        self.conn_handlers = handlers

        self.checkpoint_saver = (CheckpointSaver(expert_backends, checkpoint_dir, update_period)
                                 if checkpoint_dir is not None else None)
        self.runtime = Runtime(self.experts, **kwargs)

        if start:
            self.run_in_background(await_ready=True)
예제 #5
0
 async def rpc_ping(self, peer_info: dht_pb2.NodeInfo,
                    context: grpc.ServicerContext):
     """
     Handle a ping from a remote node that wants to be added to our routing table.

     :param peer_info: the caller's NodeInfo; it is only used if both node_id and rpc_port are set
     :param context: gRPC servicer context, used to read the caller's address via ``context.peer()``
     :returns: this node's own ``self.node_info``, in every case
     """
     if not (peer_info.node_id and peer_info.rpc_port):
         return self.node_info

     # the caller's observed address, but with the rpc port it declared for itself
     caller_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
     caller_id = DHTID.from_bytes(peer_info.node_id)
     asyncio.create_task(self.update_routing_table(caller_id, caller_endpoint))
     return self.node_info
예제 #6
0
    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
               **kwargs) -> Server:
        """
        Instantiate a server with several identical experts. See argparse comments below for details
        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
        :param num_experts: run this many identical experts; required unless expert_uids is given
        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
           means "sample random experts between myprefix.0.0 and myprefix.255.255"
        :param expert_uids: spawn experts with these exact uids; mutually exclusive with num_experts and expert_pattern
        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
        :param hidden_dim: main dimension for expert_cls
        :param num_handlers: server will use this many parallel processes to handle incoming requests;
           default = 8 * number of experts
        :param max_batch_size: total num examples in the same batch will not exceed this value
        :param device: all experts will use this device in torch notation; default: cuda if available else cpu

        :param optim_cls: uses this optimizer to train all experts; if None, SGD with lr=0 is used
        :param scheduler: if not `none`, the name of the expert LR scheduler
        :param num_warmup_steps: the number of warmup steps for LR schedule
        :param num_total_steps: the total number of steps for LR schedule
        :param clip_grad_norm: maximum gradient norm used for clipping

        :param no_dht: if specified, the server will not be attached to a dht
        :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

        :param dht_port:  DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.

        :param checkpoint_dir: directory to save and load expert checkpoints. When expert_uids is None, every
           subdirectory containing 'checkpoint_last.pt' is hosted as an expert and the remaining
           (num_experts - found) uids are generated from expert_pattern; all experts are then loaded via load_experts

        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
            for each BatchTensorProto in ExpertBackend for the respective experts.

        :param start: if True, starts server right away and returns when server is ready for requests
        :param stats_report_interval: interval between two reports of batch processing performance statistics
        :returns: an instance of cls hosting the created experts
        :raises ValueError: if checkpoint_dir holds more expert checkpoints than num_experts
        """
        if kwargs:
            # format the dict explicitly: passing it as a logging arg without a placeholder drops it from the message
            logger.info(f"Ignored kwargs: {kwargs}")
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
                (num_experts is not None and expert_uids is None)), \
            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
                               (child / 'checkpoint_last.pt').exists()]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
            else:
                expert_uids = []

            # top up the experts found on disk (if any) to num_experts using the pattern
            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # build input schemas from a sample batch of 4 examples for this expert type
        sample_input = name_to_input[expert_cls](4, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                         args_schema=args_schema,
                                                         outputs_schema=hivemind.BatchTensorDescriptor(
                                                             hidden_dim, compression=compression),
                                                         optimizer=optim_cls(expert.parameters()),
                                                         scheduler=scheduler,
                                                         num_warmup_steps=num_warmup_steps,
                                                         num_total_steps=num_total_steps,
                                                         clip_grad_norm=clip_grad_norm,
                                                         max_batch_size=max_batch_size)

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)
예제 #7
0
    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
        """
        Instantiate a server with several identical experts. See argparse comments below for details
        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
        :param num_experts: run this many identical experts
        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
           means "sample random experts between myprefix.0.0 and myprefix.255.255"
        :param expert_uids: spawn experts with these exact uids; mutually exclusive with num_experts and
           expert_pattern (unless num_experts == 0) and with load_experts
        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
        :param hidden_dim: main dimension for expert_cls
        :param num_handlers: server will use this many parallel processes to handle incoming requests;
           default = 8 * number of experts
        :param max_batch_size: total num examples in the same batch will not exceed this value
        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
        :param optim_cls: uses this optimizer to train all experts; if None, SGD with lr=0 is used
        :param no_dht: if specified, the server will not be attached to a dht
        :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

        :param dht_port:  DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.

        :param checkpoint_dir: directory to save expert checkpoints
        :param load_experts: whether to load expert checkpoints from checkpoint_dir. If any subdirectory contains
           'checkpoint_last.pt', those experts are hosted and UID generation options are ignored; otherwise
           uids are generated from num_experts / expert_pattern as usual

        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
            for each BatchTensorProto in ExpertBackend for the respective experts.

        :param start: if True, starts server right away and returns when server is ready for requests
        :returns: a Server hosting the created (and possibly restored) experts
        """
        if kwargs:
            # format the dict explicitly: passing it as a logging arg without a placeholder drops it from the message
            logger.info(f"Ignored kwargs: {kwargs}")
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

        # uids of experts that have a saved checkpoint; stays empty unless load_experts finds some
        loaded_uids = []
        if load_experts:
            assert dir_is_correct(checkpoint_dir)
            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
            loaded_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
            if loaded_uids:
                logger.info(f"Located checkpoints for experts {loaded_uids}, ignoring UID generation options")
            else:
                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")

        # validate the caller's own arguments, not the uids discovered on disk
        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"

        if loaded_uids:
            expert_uids = loaded_uids
        elif expert_uids is None:
            # nothing given explicitly and nothing restored: generate uids from the pattern
            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
            logger.info(f"Generating expert uids from pattern {expert_pattern}")
            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # build input schemas from a sample batch of 4 examples for this expert type
        sample_input = name_to_input[expert_cls](4, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                         args_schema=args_schema,
                                                         outputs_schema=hivemind.BatchTensorDescriptor(
                                                             hidden_dim, compression=compression),
                                                         opt=optim_cls(expert.parameters()),
                                                         max_batch_size=max_batch_size)

        # only restore weights for experts whose checkpoints were actually found
        if loaded_uids:
            load_weights(experts, checkpoint_dir)

        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                        start=start)
        return server