async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
    loop = asyncio.get_event_loop()
    start_time = get_dht_time()
    group_id = None

    while not future.done():
        try:
            self._pending_group_assembled.clear()
            gather_binary = self.serializer.dumps(gather)
            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
            if allreduce_group is None:
                raise AllreduceException("Averaging step failed: could not find a group.")

            group_id = allreduce_group.group_id
            self._running_groups[group_id] = allreduce_group
            self._pending_group_assembled.set()
            await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
            await loop.run_in_executor(None, self.update_tensors, allreduce_group)

            # averaging is finished, exit the loop
            gathered_items = map(self.serializer.loads, allreduce_group.gathered)
            gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
            future.set_result(gathered_data_by_peer)

        except (AllreduceException, MatchmakingException):
            time_elapsed = get_dht_time() - start_time
            if not allow_retries or (timeout is not None and timeout < time_elapsed):
                future.set_result(None)

        except Exception as e:
            future.set_exception(e)
            raise

        finally:
            _ = self._running_groups.pop(group_id, None)
            self._pending_group_assembled.set()
async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
    """
    Get peer's node id and add it to the routing table. If the peer doesn't respond, return None.

    :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
    :param validate: if True, validates that the node's endpoint is available
    :param strict: if strict=True, validation will raise an exception on failure, otherwise it will only warn
    :note: if DHTProtocol was created with listen=True, also request the peer to add you to its routing table
    :return: node's DHTID, if the peer responded and decided to send its node_id
    """
    try:
        async with self.rpc_semaphore:
            ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
            time_requested = get_dht_time()
            response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
            time_responded = get_dht_time()
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
        response = None
    responded = bool(response and response.peer and response.peer.node_id)

    if responded and validate:
        try:
            if self.server is not None and not response.available:
                raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint}. "
                                      f"Make sure that this port is open for incoming requests.")

            if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
                        response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
                    raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
                                          f"of others (local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
        except ValidationError as e:
            if strict:
                raise
            else:
                logger.warning(repr(e))

    peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
    asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
    return peer_id
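# Usage sketch (not from the source; assumes an initialized `protocol: DHTProtocol`, endpoint is
# illustrative). call_ping returns the peer's DHTID on success and None on timeout, so it doubles
# as a liveness probe and a way to learn or confirm a node's identity:
#
#     peer_id = await protocol.call_ping("192.168.1.42:31337", validate=True, strict=False)
#     if peer_id is None:
#         ...  # peer unreachable; update_routing_table was still notified with responded=False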
async def notify_stragglers_on_success(self):
    """ Find averagers that have fewer nbits and redirect them to your current nbits """
    for nbits in reversed(range(1, len(self.group_bits) - 1)):
        preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
        preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)

        if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
            await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
            break

    root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
    if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
async def notify_stragglers(self):
    """ Find averagers that have fewer nbits and redirect them to your current nbits """
    for nbits in reversed(range(1, len(self.group_bits) - 1)):
        preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
        preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)

        if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
            await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
            break

    root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
    if isinstance(root_data, dict) and root_data.get(
            self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
        return
    await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
    """ Some node wants us to add it to our routing table. """
    response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                    dht_time=get_dht_time(), available=False)

    if request.peer and request.peer.node_id and request.peer.rpc_port:
        sender_id = DHTID.from_bytes(request.peer.node_id)
        if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
            sender_endpoint = request.peer.endpoint  # if peer has a preferred endpoint, use it
        else:
            sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)

        response.sender_endpoint = sender_endpoint
        if request.validate:
            response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id

        asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
                                                      responded=response.available or not request.validate))

    return response
def _average_parameters_in_background(
        lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
        averaging_period: float, verbose: bool, **kwargs):
    """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
    while not stop_event.is_set():
        update_event.wait()
        update_event.clear()
        if stop_event.is_set():
            break

        if averaging_period:
            current_time = get_dht_time()
            # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
            time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
            time.sleep(time_to_nearest_interval)

        if verbose:
            logger.info("Starting a new averaging round with current parameters.")
        try:
            group_info = averager.step(lock_parameters, **kwargs)
            if verbose:
                if group_info is not None:
                    logger.info(f"Finished averaging round with {len(group_info)} peers.")
                else:
                    logger.warning("Averaging round failed: could not find group.")
        except Exception as e:
            logger.error(f"Averaging round failed: caught {e}.")
def shutdown(self):
    logger.debug("Shutting down averager...")
    self.averager.shutdown()
    logger.debug("Sending goodbye to peers...")
    self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
                   expiration_time=get_dht_time() + self.metadata_expiration)
    logger.debug(f"{self.__class__.__name__} is shut down.")
def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
    assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
    self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
    self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
    self.replacement_nodes: Dict[DHTID, Endpoint] = {}
    self.nodes_requested_for_ping: Set[DHTID] = set()
    self.last_updated = get_dht_time()
async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
                            scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
                            ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
    num_workers = num_workers or dht.max_workers or beam_size
    beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
    unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
    pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()

    while len(beam) < beam_size and (unattempted_indices or pending_tasks):
        # dispatch additional tasks
        while unattempted_indices and len(pending_tasks) < num_workers:
            next_index = unattempted_indices.pop()  # note: this is the best unattempted index because of sort order
            next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
            pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))

        # await the next best prefix to be fetched
        pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
        try:
            maybe_prefix_data = await pending_task
            if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
                              if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
                              and len(match.value) == 2}
                if successors:
                    beam.append((scores[pending_best_index], pending_best_prefix, successors))
            elif maybe_prefix_data is None and negative_caching:
                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
                                               expiration_time=get_dht_time() + dht.default_expiration))
        except asyncio.CancelledError:
            for _, _, remaining_task in pending_tasks:  # each entry is an (index, prefix, task) triple
                remaining_task.cancel()
            raise
    return beam
async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
                           expiration: DHTExpiration) -> Dict[ExpertUID, bool]:
    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
    expiration_time = get_dht_time() + expiration
    data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
    for uid in uids:
        data_to_store[uid, None] = endpoint
        prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
        for i in range(prefix.count(UID_DELIMITER) - 1):
            prefix, last_coord = split_uid(prefix)
            data_to_store[prefix, last_coord] = [uid, endpoint]

    keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
    store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
    return store_ok
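# Illustration (hypothetical uid and endpoint; assumes UID_DELIMITER == '.' and that split_uid
# strips a trailing delimiter before splitting): declaring expert "ffn.1.2.3" served at
# "host:1337" fills data_to_store with one entry for the expert itself plus one
# (prefix, coordinate) entry per intermediate grid level, which the beam-search helpers traverse:
#
#     ("ffn.1.2.3", None): "host:1337"
#     ("ffn.1.2.", 3):     ["ffn.1.2.3", "host:1337"]
#     ("ffn.1.", 2):       ["ffn.1.2.3", "host:1337"]
#
# The top-level "ffn." entry is not stored: _get_initial_beam enumerates first coordinates
# directly from the scores, so only "ffn.1." and deeper prefixes are ever queried.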
@contextmanager  # note: requires `from contextlib import contextmanager`; pause() is used as `with ema.pause():`
def pause(self):
    """ While inside this context, EMA will not count the time passed towards the performance estimate """
    self.paused, was_paused = True, self.paused
    try:
        yield
    finally:
        self.timestamp = get_dht_time()
        self.paused = was_paused
def _average_parameters_in_background(
        lock_parameters: Lock, update_event: Event, stop_event: Event, averager: DecentralizedAverager,
        opt: torch.optim.Optimizer, averaging_period: float, verbose: bool, **kwargs):
    """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
    while not stop_event.is_set():
        update_event.wait()
        update_event.clear()
        if stop_event.is_set():
            break

        if averaging_period:
            current_time = get_dht_time()
            # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
            time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
            time.sleep(time_to_nearest_interval)

        with lock_parameters, averager.get_tensors() as averaged_tensors:
            local_tensors = tuple(p for group in opt.param_groups for p in group['params'])
            assert len(local_tensors) == len(averaged_tensors), "The number of optimized parameters should not change."

            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
                averaged_tensor[...] = local_tensor.cpu().float()

        try:
            if verbose:
                logger.info("Starting a new averaging round with current parameters.")
            group_info = averager.step(**kwargs)

            if group_info is not None:
                with lock_parameters, averager.get_tensors() as averaged_tensors:
                    for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
                        local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype)
                if verbose:
                    logger.info(f"Finished averaging round with {len(group_info)} peers.")
            else:
                if verbose:
                    logger.warning("Averaging round failed: could not find group.")
        except Exception as e:
            logger.error(f"Averaging round failed: caught {e}.")
async def _get_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID],
                       expiration_time: Optional[DHTExpiration]) -> List[Optional[RemoteExpert]]:
    if expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
    for i, uid in enumerate(uids):
        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
            experts[i] = RemoteExpert(uid, found[uid].value)
    return experts
def check_collaboration_state_periodically(self):
    """
    Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
    """
    while self.is_alive():
        time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
        if self.collaboration_state_updated.wait(time_to_next_update):
            self.collaboration_state_updated.clear()
            continue  # if state was updated externally, reset the timer

        with self.lock_collaboration_state:
            self.collaboration_state = self.fetch_collaboration_state()
def fetch_collaboration_state(self) -> CollaborationState:
    """ Read performance statistics reported by peers, estimate progress towards next batch """
    response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float('inf'))
    current_time = get_dht_time()

    if not isinstance(response, dict) or len(response) == 0:
        logger.log(self.status_loglevel, f"Found no active peers: {response}")
        # remaining samples / local samples-per-second estimate
        local_eta_next_step = max(0, self.target_batch_size - self.local_samples_accumulated
                                  ) / self.performance_ema.samples_per_second
        return CollaborationState(self.local_step, self.local_samples_accumulated, self.target_batch_size,
                                  num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
                                  next_fetch_time=current_time + self.default_refresh_period)

    valid_peer_states = [TrainingState.parse_obj(peer_state.value)
                         for peer_state in response.values() if peer_state.value is not None]

    num_peers = len(valid_peer_states)
    num_clients = sum(state.client_mode for state in valid_peer_states)
    global_optimizer_step = self.local_step
    for state in valid_peer_states:
        if not state.client_mode:
            global_optimizer_step = max(global_optimizer_step, state.step)

    total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
    for state in valid_peer_states:
        total_samples_per_second += state.samples_per_second
        if state.step == global_optimizer_step:
            total_samples_accumulated += state.samples_accumulated
            estimated_current_samples += (state.samples_accumulated +
                                          max(0, current_time - state.time) * state.samples_per_second)
    # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
    # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

    estimated_samples_remaining = self.target_batch_size - estimated_current_samples
    estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

    expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
    time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                       a_min=self.min_refresh_period, a_max=self.max_refresh_period))
    logger.log(self.status_loglevel, f"Collaboration accumulated {total_samples_accumulated} samples from "
                                     f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
                                     f"(refresh in {time_to_next_fetch:.2f}s)")
    return CollaborationState(
        global_optimizer_step, total_samples_accumulated, target_batch_size=self.target_batch_size,
        num_peers=num_peers, num_clients=num_clients, eta_next_step=current_time + estimated_time_to_next_step,
        next_fetch_time=current_time + time_to_next_fetch)
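# Worked example of the refresh schedule above (illustrative numbers): with
# estimated_time_to_next_step = 30s, num_peers = 8, expected_drift_peers = 3 and
# expected_drift_rate = 0.2, expected_max_peers = max(8 + 3, 8 * 1.2) = 11, so
# time_to_next_fetch = clip(30 * 8 / 11) ≈ 21.8s. Refreshing earlier than the estimated ETA
# hedges against newly joined peers finishing the batch sooner than the current estimate.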
def update(self, num_processed: int) -> float:
    """
    :param num_processed: how many items were processed since last call
    :returns: current estimate of performance (samples per second), capped at 1 / eps
    """
    assert not self.paused, "PerformanceEMA is currently paused"
    assert num_processed > 0, f"Can't register processing {num_processed} samples"
    self.timestamp, old_timestamp = get_dht_time(), self.timestamp
    seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
    self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
    self.num_updates += 1
    adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
    self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
    return self.samples_per_second
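# Note on the bias correction above: with ema_0 = 0, the expectation of the EMA after t updates is
# (1 - (1 - alpha) ** t) * E[seconds_per_sample], so dividing by that factor (the same
# zero-debiasing used in Adam) keeps the first few estimates from being dragged towards zero.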
def report_training_progress(self):
    """ Periodically publish metadata and the current number of samples accumulated towards the next step """
    while self.is_alive():
        self.should_report_progress.wait()
        self.should_report_progress.clear()
        with self.lock_local_progress:
            current_time = get_dht_time()
            local_state_info = TrainingState(
                endpoint=self.averager.endpoint,
                step=self.local_step,
                samples_accumulated=self.local_samples_accumulated,
                samples_per_second=self.performance_ema.samples_per_second,
                time=current_time,
                client_mode=not self.averager.listen)

        self.dht.store(key=self.training_progress_key, subkey=self._local_public_key,
                       value=local_state_info.dict(), expiration_time=current_time + self.metadata_expiration,
                       return_future=True)
def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> bool:
    """
    Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
    If the bucket is full, keep track of the node in a replacement list, per section 4.1 of the paper.

    :param node_id: dht node identifier that should be added or moved to the front of the bucket
    :param endpoint: network address associated with that node id
    :note: this function has a side-effect of resetting KBucket.last_updated time
    """
    if node_id in self.nodes_requested_for_ping:
        self.nodes_requested_for_ping.remove(node_id)
    self.last_updated = get_dht_time()
    if node_id in self.nodes_to_endpoint:
        del self.nodes_to_endpoint[node_id]
        self.nodes_to_endpoint[node_id] = endpoint  # delete-and-reinsert moves the node to the "most recently seen" end
    elif len(self.nodes_to_endpoint) < self.size:
        self.nodes_to_endpoint[node_id] = endpoint
    else:
        if node_id in self.replacement_nodes:
            del self.replacement_nodes[node_id]
        self.replacement_nodes[node_id] = endpoint
        return False
    return True
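# Minimal sketch (assumes hivemind's DHTID.generate(); endpoints are illustrative). With size=2,
# a third node cannot enter the bucket and lands in the replacement list instead:
#
#     bucket = KBucket(lower=0, upper=2 ** (DHTID.HASH_NBYTES * 8), size=2)
#     a, b, c = DHTID.generate(), DHTID.generate(), DHTID.generate()
#     assert bucket.add_or_update_node(a, "1.2.3.4:1111")
#     assert bucket.add_or_update_node(b, "1.2.3.4:2222")
#     assert not bucket.add_or_update_node(c, "1.2.3.4:3333")  # full: c goes to replacement_nodes
#     assert c in bucket.replacement_nodes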
async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
                                 negative_caching: bool, cache_expiration: DHTExpiration,
                                 num_workers: Optional[int] = None
                                 ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
    grid_size = grid_size or float('inf')
    num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
    dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
    successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
    for prefix, found in dht_responses.items():
        if found and isinstance(found.value, dict):
            successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
                                  if isinstance(coord, Coordinate) and 0 <= coord < grid_size
                                  and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
        else:
            successors[prefix] = {}
            if found is None and negative_caching:
                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                asyncio.create_task(node.store(prefix, subkey=-1, value=None,
                                               expiration_time=get_dht_time() + cache_expiration))
    return successors
def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
    self.alpha, self.eps, self.num_updates = alpha, eps, 0
    self.ema_seconds_per_sample, self.samples_per_second = 0, eps
    self.timestamp = get_dht_time()
    self.paused = False
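# Usage sketch (combines the PerformanceEMA methods above; assumes `import time`, numbers are
# illustrative):
#
#     ema = PerformanceEMA(alpha=0.1)
#     time.sleep(0.5)
#     rate = ema.update(num_processed=100)  # roughly 200 samples/s after the first update
#     with ema.pause():                     # wall-clock time spent here is excluded
#         time.sleep(2.0)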
def step(self, batch_size: Optional[int] = None, **kwargs):
    """
    Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

    :param batch_size: optional override for batch_size_per_step from init
    :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
    """
    if self.batch_size_per_step is None:
        if batch_size is None:
            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
        logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
        self.batch_size_per_step = batch_size
    batch_size = batch_size if batch_size is not None else self.batch_size_per_step

    if not self.is_synchronized:
        logger.log(self.status_loglevel, "Peer is out of sync.")
        self.load_state_from_peers()
        return

    if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
        logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                       f"but metadata expired in {self.metadata_expiration} s.")

    self.accumulate_grads_(batch_size)

    with self.lock_local_progress:
        self.local_samples_accumulated += batch_size
        self.local_steps_accumulated += 1
        self.performance_ema.update(num_processed=batch_size)
        self.should_report_progress.set()

    if not self.collaboration_state.ready_for_step:
        return

    logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
    self.collaboration_state = self.fetch_collaboration_state()
    self.collaboration_state_updated.set()

    if not self.is_synchronized:
        self.load_state_from_peers()
        return

    with self.performance_ema.pause(), self.lock_collaboration_state:
        # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
        self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
        current_step, group_info = self.averager.local_step, None

        if self.collaboration_state.num_peers > 1:
            mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
            weight = self.local_samples_accumulated / mean_samples_per_worker
            try:
                group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                if group_info:
                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
            except BaseException as e:
                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
        else:
            logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
                                             f"{self.collaboration_state.num_peers} peer(s).")

        self.opt.step()
        self.reset_accumulated_grads_()
        self.local_samples_accumulated = self.local_steps_accumulated = 0
        self.collaboration_state.register_step(current_step + 1)
        self.averager.local_step = current_step + 1
        self.collaboration_state_updated.set()
        self.update_scheduler()

    logger.log(self.status_loglevel, "Optimizer step: done!")
    return group_info
def step(self, batch_size: Optional[int] = None, **kwargs):
    """
    Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

    :param batch_size: optional override for batch_size_per_step from init
    :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
    """
    if batch_size is None and self.batch_size_per_step is None:
        raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
    batch_size = self.batch_size_per_step if batch_size is None else batch_size

    if not self.is_synchronized:
        self.load_state_from_peers()
        return

    if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
        logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                       f"but metadata expired in {self.metadata_expiration} s.")

    with self.lock_local_progress:
        self.local_samples_accumulated += batch_size
        self.local_steps_accumulated += 1
        self.performance_ema.update(num_processed=batch_size)
        self.should_report_progress.set()

    if not self.collaboration_state.ready_for_step:
        return

    logger.log(self.status_loglevel, "Averaging parameters and gradients with peers...")
    self.collaboration_state = self.fetch_collaboration_state()
    self.collaboration_state_updated.set()

    if not self.is_synchronized:
        self.load_state_from_peers()
        return

    with self.performance_ema.pause(), self.lock_collaboration_state:
        if self.collaboration_state.num_peers > 1:
            mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
            weight = self.local_samples_accumulated / mean_samples_per_worker
            output = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
        else:
            logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
                                             f"{self.collaboration_state.num_peers} peer(s).")
            output = None

        self.averager.local_step += 1
        self.opt.step()
        self.opt.zero_grad()
        self.local_samples_accumulated = self.local_steps_accumulated = 0
        self.collaboration_state.register_step()
        self.collaboration_state_updated.set()
        self.update_scheduler()

    logger.log(self.status_loglevel, "Optimizer step: done!")
    return output
@property  # note: accessed as an attribute (e.g. `self.collaboration_state.ready_for_step` in .step above)
def ready_for_step(self):
    return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
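# Minimal sketch (constructor signature inferred from fetch_collaboration_state above; values are
# illustrative): a peer becomes ready once the collaboration-wide sample target is reached, or once
# the estimated deadline passes, which guards against stalling when other peers drop out mid-round:
#
#     now = get_dht_time()
#     state = CollaborationState(0, 4096, target_batch_size=4096, num_peers=2, num_clients=0,
#                                eta_next_step=now + 60, next_fetch_time=now + 5)
#     assert state.ready_for_step  # sample target reached; no need to wait for eta_next_step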