示例#1
0
    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
        """
        Run the averaging attempt loop: find a group via matchmaking, run all-reduce, and resolve ``future``
        with the gathered metadata. Retries recoverable failures until the future is resolved, retries are
        disallowed, or ``timeout`` seconds have elapsed since the first attempt.

        :param future: resolved with {endpoint: gathered_data} on success, None on give-up,
            or set to the exception on unexpected failure
        :param gather: extra metadata serialized and shared with groupmates during matchmaking
        :param allow_retries: if False, a single recoverable failure resolves the future with None
        :param timeout: stop retrying once this many seconds have passed since start (None = no limit)
        """
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")

                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
                # update_tensors runs in a thread pool so the event loop stays responsive
                await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                # averaging is finished, exit the loop
                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)

            except (AllreduceException, MatchmakingException):
                # recoverable failure: retry unless retries are disabled or the overall timeout ran out
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    future.set_result(None)

            except Exception as e:
                # unexpected failure: report it to the caller and crash this coroutine
                future.set_exception(e)
                raise
            finally:
                # always detach the (possibly unfinished) group and unblock anyone waiting on assembly
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
示例#2
0
    async def call_ping(self,
                        peer: Endpoint,
                        validate: bool = False,
                        strict: bool = True) -> Optional[DHTID]:
        """
        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
        :param validate: if True, validates that node's endpoint is available
        :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

        :return: node's DHTID, if peer responded and decided to send his node_id
        """
        try:
            # rpc_semaphore bounds the number of concurrent outbound RPCs
            async with self.rpc_semaphore:
                ping_request = dht_pb2.PingRequest(peer=self.node_info,
                                                   validate=validate)
                time_requested = get_dht_time()
                response = await self._get_dht_stub(peer).rpc_ping(
                    ping_request, timeout=self.wait_timeout)
                time_responded = get_dht_time()
        except grpc.aio.AioRpcError as error:
            # an unreachable peer is not an error here: treat it as "no response"
            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
            response = None
        responded = bool(response and response.peer and response.peer.node_id)

        if responded and validate:
            try:
                # response.available is set by the remote rpc_ping after it pinged us back;
                # False means the peer could not reach this node's endpoint
                if self.server is not None and not response.available:
                    raise ValidationError(
                        f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
                        f"Make sure that this port is open for incoming requests."
                    )

                # a non-default dht_time means the peer reported its clock; verify it is close to ours
                if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                    if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
                            response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
                        raise ValidationError(
                            f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
                            f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})"
                        )
            except ValidationError as e:
                # strict mode propagates validation failures; otherwise they are only logged
                if strict:
                    raise
                else:
                    logger.warning(repr(e))

        peer_id = DHTID.from_bytes(
            response.peer.node_id) if responded else None
        # update the routing table in the background, whether or not the peer responded
        asyncio.create_task(
            self.update_routing_table(peer_id, peer, responded=responded))
        return peer_id
示例#3
0
    async def notify_stragglers_on_success(self):
        """
        Find averagers that have fewer nbits and redirect them to your current nbits.

        Fix: the root-key check previously accessed ``root_data.value`` with ``in`` without verifying
        that the stored value is a dict, raising TypeError for a non-dict entry. It now unpacks the
        (value, expiration) pair and isinstance-checks it, matching the defensive pattern used both in
        the loop below and in ``notify_stragglers``.
        """
        for nbits in reversed(range(1, len(self.group_bits) - 1)):
            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)

            # if stragglers exist at this shorter key and nobody has declared nbits there yet, declare ours
            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
                break

        # at the root key: declare our nbits unless a valid entry is already present
        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
        if not isinstance(root_data, dict) or self.RESERVED_KEY_FOR_NBITS not in root_data:
            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
示例#4
0
    async def notify_stragglers(self):
        """ Find averagers that have fewer nbits and redirect them to your current nbits """
        # scan keys with progressively shorter bit suffixes, longest first.
        # NOTE(review): nbits starts at 1, so the `if nbits else ''` fallback below is never taken
        for nbits in reversed(range(1, len(self.group_bits) - 1)):
            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)

            # if stragglers exist at this shorter key and nobody has declared nbits there yet, declare ours
            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
                break

        # at the root key, skip declaring if a sufficiently fresh nbits entry already exists;
        # entries are (value, expiration) pairs — presumably "fresh" means it outlives now + grace period
        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
        if isinstance(root_data, dict) and root_data.get(
                self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
            return
        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
示例#5
0
    async def rpc_ping(self, request: dht_pb2.PingRequest,
                       context: grpc.ServicerContext):
        """ Some node wants us to add it to our routing table. """
        # default response: our node info, our clock, available=False until proven otherwise
        response = dht_pb2.PingResponse(peer=self.node_info,
                                        sender_endpoint=context.peer(),
                                        dht_time=get_dht_time(),
                                        available=False)

        # only register the sender if it identified itself (node_id + a port we can reach it on)
        if request.peer and request.peer.node_id and request.peer.rpc_port:
            sender_id = DHTID.from_bytes(request.peer.node_id)
            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
            else:
                # otherwise, derive the endpoint from the connection address with the advertised port
                sender_endpoint = replace_port(context.peer(),
                                               new_port=request.peer.rpc_port)

            response.sender_endpoint = sender_endpoint
            if request.validate:
                # ping the sender back; it is "available" only if it answers with the same node id
                response.available = await self.call_ping(
                    response.sender_endpoint, validate=False) == sender_id

            # update our routing table in the background; the peer counts as responsive
            # if it passed validation or did not request validation at all
            asyncio.create_task(
                self.update_routing_table(sender_id,
                                          sender_endpoint,
                                          responded=response.available
                                          or not request.validate))

        return response
示例#6
0
    def _average_parameters_in_background(
            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
            averaging_period: float, verbose: bool, **kwargs):
        """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
        while not stop_event.is_set():
            # sleep until the trainer reports new parameters, then consume the signal
            update_event.wait()
            update_event.clear()
            if stop_event.is_set():
                break

            if averaging_period:
                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
                remainder = get_dht_time() % averaging_period
                time.sleep(max(0.0, averaging_period - remainder))

            if verbose:
                logger.info(f"Starting a new averaging round with current parameters.")
            try:
                group_info = averager.step(lock_parameters, **kwargs)
                if verbose and group_info is not None:
                    logger.info(f"Finished averaging round in with {len(group_info)} peers.")
                elif verbose:
                    logger.warning(f"Averaging round failed: could not find group.")
            except Exception as e:
                logger.error(f"Averaging round failed: caught {e}.")
示例#7
0
 def shutdown(self):
     """ Shut down the averager and store a goodbye (None) entry under this peer's progress key. """
     logger.debug("Shutting down averager...")
     self.averager.shutdown()
     logger.debug("Sending goodbye to peers...")
     # value=None under our public key marks this peer's progress entry as departed
     goodbye_expiration = get_dht_time() + self.metadata_expiration
     self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
                    expiration_time=goodbye_expiration)
     logger.debug(f"{self.__class__.__name__} is shut down.")
示例#8
0
 def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
     """
     A bucket that covers a contiguous range of DHT ids and holds at most ``size`` nodes.

     :param lower: lower bound of the id range covered by this bucket
     :param upper: upper bound of the id range covered by this bucket
     :param size: maximum number of nodes stored in this bucket
     :param depth: how many times the full id space was split to produce this bucket
     """
     # the range width must equal the id space remaining after ``depth`` splits
     assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
     self.lower = lower
     self.upper = upper
     self.size = size
     self.depth = depth
     self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
     self.replacement_nodes: Dict[DHTID, Endpoint] = {}
     self.nodes_requested_for_ping: Set[DHTID] = set()
     self.last_updated = get_dht_time()
示例#9
0
    async def _get_initial_beam(
        dht: DHT,
        node: DHTNode,
        prefix: ExpertPrefix,
        beam_size: int,
        scores: Tuple[float, ...],
        negative_caching: bool,
        num_workers: Optional[int] = None
    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
        """
        Fetch up to ``beam_size`` next-dimension prefixes that have at least one well-formed successor.

        Fix: on cancellation, the cleanup loop unpacked two names from the three-element
        (index, prefix, task) tuples stored in ``pending_tasks``, which raised ValueError instead of
        cancelling the in-flight tasks; it now unpacks all three elements.

        :param dht: DHT instance (used for its max_workers and default_expiration settings)
        :param node: DHTNode used to fetch prefix entries
        :param prefix: common expert prefix; candidate keys are f"{prefix}{index}{UID_DELIMITER}"
        :param beam_size: stop once this many prefixes with successors were found
        :param scores: per-index scores; higher-scoring indices are attempted first
        :param negative_caching: if True, store a short-lived "no such prefix" entry for missing prefixes
        :param num_workers: max concurrent DHT requests (default: dht.max_workers or beam_size)
        :returns: list of (score, prefix, {coordinate: UidEndpoint}) for prefixes with successors
        """
        num_workers = num_workers or dht.max_workers or beam_size
        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
        unattempted_indices: List[Coordinate] = sorted(
            range(len(scores)), key=scores.__getitem__)  # from worst to best
        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()

        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
            # dispatch additional tasks, best-scoring indices first (pop() takes the best remaining)
            while unattempted_indices and len(pending_tasks) < num_workers:
                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
                pending_tasks.append(
                    (next_index, next_best_prefix,
                     asyncio.create_task(node.get(next_best_prefix))))

            # await the next best prefix to be fetched
            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
            try:
                maybe_prefix_data = await pending_task
                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                    # keep only well-formed entries: integer coordinate -> [uid, endpoint] pair
                    successors = {
                        coord: UidEndpoint(*match.value)
                        for coord, match in maybe_prefix_data.value.items()
                        if isinstance(coord, Coordinate)
                        and isinstance(getattr(match, 'value', None), list)
                        and len(match.value) == 2
                    }
                    if successors:
                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
                elif maybe_prefix_data is None and negative_caching:
                    logger.debug(
                        f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}"
                    )
                    # fire-and-forget: cache the absence so future lookups skip this prefix until expiration
                    asyncio.create_task(
                        node.store(pending_best_prefix,
                                   subkey=-1,
                                   value=None,
                                   expiration_time=get_dht_time() + dht.default_expiration))

            except asyncio.CancelledError:
                # bug fix: pending_tasks holds (index, prefix, task) triples; the previous
                # `for _, pending_task in pending_tasks` raised ValueError here
                for _, _, queued_task in pending_tasks:
                    queued_task.cancel()
                raise
        return beam
示例#10
0
async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID],
                           endpoint: Endpoint,
                           expiration: DHTExpiration) -> Dict[ExpertUID, bool]:
    """
    Store the given expert uids and all of their prefixes to the DHT.

    :param dht: DHT instance (used for its max_workers setting)
    :param node: DHTNode that performs the actual store
    :param uids: full expert uids to declare
    :param endpoint: network address where these experts can be reached
    :param expiration: store entries for this many seconds relative to current DHT time
    :returns: per-key store status, as returned by node.store_many
    """
    num_workers = len(uids) if dht.max_workers is None else min(
        len(uids), dht.max_workers)
    expiration_time = get_dht_time() + expiration
    data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]],
                        DHTValue] = {}
    for uid in uids:
        # the full uid maps directly to the endpoint (no subkey)
        data_to_store[uid, None] = endpoint
        # uids with only one delimiter are padded with FLAT_EXPERT so they still yield a prefix level
        prefix = uid if uid.count(
            UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
        # every prefix level stores {last_coordinate: [uid, endpoint]} so beam search can descend
        for i in range(prefix.count(UID_DELIMITER) - 1):
            prefix, last_coord = split_uid(prefix)
            data_to_store[prefix, last_coord] = [uid, endpoint]

    # transpose {(key, subkey): value} into parallel sequences for store_many
    keys, maybe_subkeys, values = zip(
        *((key, subkey, value)
          for (key, subkey), value in data_to_store.items()))
    store_ok = await node.store_many(keys,
                                     values,
                                     expiration_time,
                                     subkeys=maybe_subkeys,
                                     num_workers=num_workers)
    return store_ok
示例#11
0
 def pause(self):
     """ While inside this context, EMA will not count the time passed towards the performance estimate """
     was_paused = self.paused
     self.paused = True
     try:
         yield
     finally:
         # refresh the timestamp so the interval spent inside the context is excluded from the next update
         self.timestamp = get_dht_time()
         self.paused = was_paused
示例#12
0
    def _average_parameters_in_background(
            lock_parameters: Lock, update_event: Event, stop_event: Event,
            averager: DecentralizedAverager, opt: torch.optim.Optimizer,
            averaging_period: float, verbose: bool, **kwargs):
        """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
        while not stop_event.is_set():
            # sleep until the trainer signals fresh parameters; re-check the stop flag after waking
            update_event.wait()
            update_event.clear()
            if stop_event.is_set():
                break

            if averaging_period:
                current_time = get_dht_time()
                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
                time_to_nearest_interval = max(
                    0.0, averaging_period - current_time % averaging_period)
                time.sleep(time_to_nearest_interval)

            # copy current optimizer parameters into the averager's buffers (as CPU float32)
            with lock_parameters, averager.get_tensors() as averaged_tensors:
                local_tensors = tuple(p for group in opt.param_groups
                                      for p in group['params'])
                assert len(local_tensors) == len(
                    averaged_tensors
                ), "The number of optimized parameters should not change."

                for local_tensor, averaged_tensor in zip(
                        local_tensors, averaged_tensors):
                    averaged_tensor[...] = local_tensor.cpu().float()

            try:
                if verbose:
                    logger.info(
                        f"Starting a new averaging round with current parameters."
                    )
                group_info = averager.step(**kwargs)

                if group_info is not None:
                    # averaging succeeded: copy the averaged values back into the optimizer's parameters
                    with lock_parameters, averager.get_tensors(
                    ) as averaged_tensors:
                        for local_tensor, averaged_tensor in zip(
                                local_tensors, averaged_tensors):
                            local_tensor[...] = averaged_tensor.to(
                                dtype=local_tensor.dtype)
                    if verbose:
                        logger.info(
                            f"Finished averaging round in with {len(group_info)} peers."
                        )
                else:
                    if verbose:
                        logger.warning(
                            f"Averaging round failed: could not find group.")
            except Exception as e:
                # a failed round is logged and skipped; the loop keeps serving future rounds
                logger.error(f"Averaging round failed: caught {e}.")
示例#13
0
async def _get_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
                       ) -> List[Optional[RemoteExpert]]:
    """ Look up each uid in the DHT; wrap every endpoint found into a RemoteExpert, None where not found. """
    if expiration_time is None:
        expiration_time = get_dht_time()
    max_workers = dht.max_workers
    num_workers = len(uids) if max_workers is None else min(len(uids), max_workers)
    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    experts: List[Optional[RemoteExpert]] = []
    for uid in uids:
        entry = found[uid]
        if entry is not None and isinstance(entry.value, Endpoint):
            experts.append(RemoteExpert(uid, entry.value))
        else:
            experts.append(None)
    return experts
示例#14
0
    def check_collaboration_state_periodically(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        while self.is_alive():
            next_fetch = self.collaboration_state.next_fetch_time
            time_to_next_update = max(0.0, next_fetch - get_dht_time())
            updated_externally = self.collaboration_state_updated.wait(time_to_next_update)
            if updated_externally:
                # if state was updated externally, reset timer
                self.collaboration_state_updated.clear()
                continue

            with self.lock_collaboration_state:
                self.collaboration_state = self.fetch_collaboration_state()
示例#15
0
    def fetch_collaboration_state(self) -> CollaborationState:
        """
        Read performance statistics reported by peers, estimate progress towards next batch

        Fix: guard the ETA division against ZeroDivisionError when every valid peer reports zero
        samples_per_second (the sum is floored at performance_ema.eps before dividing).

        :returns: a CollaborationState with the estimated global step, accumulated samples and refresh schedule
        """
        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float('inf'))
        current_time = get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            # no peers found: fall back to local-only estimates and retry after the default refresh period
            logger.log(self.status_loglevel, f"Found no active peers: {response}")
            local_eta_next_step = max(0, self.target_batch_size - self.local_steps_accumulated
                                      ) / self.performance_ema.samples_per_second
            return CollaborationState(self.local_step, self.local_samples_accumulated, self.target_batch_size,
                                      num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
                                      next_fetch_time=current_time + self.default_refresh_period)

        valid_peer_states = [TrainingState.parse_obj(peer_state.value)
                             for peer_state in response.values()
                             if peer_state.value is not None]

        num_peers = len(valid_peer_states)
        num_clients = sum(state.client_mode for state in valid_peer_states)
        # the global step is the maximum step among non-client peers (and never behind our own)
        global_optimizer_step = self.local_step
        for state in valid_peer_states:
            if not state.client_mode:
                global_optimizer_step = max(global_optimizer_step, state.step)

        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

        for state in valid_peer_states:
            total_samples_per_second += state.samples_per_second
            if state.step == global_optimizer_step:
                total_samples_accumulated += state.samples_accumulated
                estimated_current_samples += (state.samples_accumulated +
                                              max(0, current_time - state.time) * state.samples_per_second)
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        # bug fix: floor the denominator to avoid ZeroDivisionError if all peers report zero throughput
        total_samples_per_second = max(total_samples_per_second, self.performance_ema.eps)
        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

        # spread fetches out: the more peers there are (plus expected drift), the less often each one refreshes
        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                           a_min=self.min_refresh_period, a_max=self.max_refresh_period))
        logger.log(self.status_loglevel, f"Collaboration accumulated {total_samples_accumulated} samples from "
                                         f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
                                         f"(refresh in {time_to_next_fetch:.2f}s.)")
        return CollaborationState(
            global_optimizer_step, total_samples_accumulated, target_batch_size=self.target_batch_size,
            num_peers=num_peers, num_clients=num_clients, eta_next_step=current_time + estimated_time_to_next_step,
            next_fetch_time=current_time + time_to_next_fetch)
示例#16
0
 def update(self, num_processed: int) -> float:
     """
     Register that ``num_processed`` new samples were handled and refresh the moving average.

     :param num_processed: how many items were processed since last call (must be positive)
     :returns: current estimate of performance (samples per second), at most 1 / eps
     """
     assert not self.paused, "PerformanceEMA is currently paused"
     assert num_processed > 0, f"Can't register processing {num_processed} samples"
     now = get_dht_time()
     elapsed = max(0, now - self.timestamp)
     self.timestamp = now
     seconds_per_sample = elapsed / num_processed
     # standard EMA update, then bias correction for the first few updates
     self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
     self.num_updates += 1
     bias_correction = 1 - (1 - self.alpha) ** self.num_updates
     adjusted_seconds_per_sample = self.ema_seconds_per_sample / bias_correction
     # eps floor prevents division by zero when no measurable time has passed
     self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
     return self.samples_per_second
示例#17
0
    def report_training_progress(self):
        """ Periodically publish metadata and the current number of samples accumulated towards the next step """
        while self.is_alive():
            # sleep until the trainer signals that there is fresh progress to report
            self.should_report_progress.wait()
            self.should_report_progress.clear()

            with self.lock_local_progress:
                # snapshot local progress under the lock so reported numbers are mutually consistent
                current_time = get_dht_time()
                local_state_info = TrainingState(endpoint=self.averager.endpoint,
                                                 step=self.local_step,
                                                 samples_accumulated=self.local_samples_accumulated,
                                                 samples_per_second=self.performance_ema.samples_per_second,
                                                 time=current_time,
                                                 client_mode=not self.averager.listen)

            # the DHT store happens outside the lock, as a non-blocking (future-returning) call
            self.dht.store(key=self.training_progress_key,
                           subkey=self._local_public_key,
                           value=local_state_info.dict(),
                           expiration_time=current_time + self.metadata_expiration,
                           return_future=True)
示例#18
0
    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> bool:
        """
        Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
        If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.

        :param node_id: dht node identifier that should be added or moved to the front of bucket
        :param endpoint: network address associated with that node id
        :note: this function has a side-effect of resetting KBucket.last_updated time
        """
        # the node answered, so any pending ping request for it is satisfied
        self.nodes_requested_for_ping.discard(node_id)
        self.last_updated = get_dht_time()

        if node_id in self.nodes_to_endpoint:
            # known node: move it to the most-recently-seen position (end of the dict)
            del self.nodes_to_endpoint[node_id]
            self.nodes_to_endpoint[node_id] = endpoint
            return True

        if len(self.nodes_to_endpoint) < self.size:
            self.nodes_to_endpoint[node_id] = endpoint
            return True

        # bucket is full: remember the node as a replacement candidate (also moved to the end)
        self.replacement_nodes.pop(node_id, None)
        self.replacement_nodes[node_id] = endpoint
        return False
示例#19
0
 async def _get_active_successors(
     dht: DHT,
     node: DHTNode,
     prefixes: List[ExpertPrefix],
     grid_size: Optional[int],
     negative_caching: bool,
     cache_expiration: DHTExpiration,
     num_workers: Optional[int] = None
 ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
     """
     Fetch the active successors of every prefix from the DHT.

     :param dht: DHT instance (used for its max_workers setting)
     :param node: DHTNode that performs the lookups
     :param prefixes: expert prefixes to query
     :param grid_size: if given, drop successors whose coordinate lies outside [0, grid_size)
     :param negative_caching: if True, store a short-lived "no such prefix" entry for prefixes not found
     :param cache_expiration: lifetime in seconds of the negative-cache entries
     :param num_workers: max concurrent DHT requests (defaults based on len(prefixes) and dht.max_workers)
     :returns: {prefix: {coordinate: UidEndpoint}}; prefixes without valid successors map to an empty dict
     """
     grid_size = grid_size or float('inf')
     num_workers = num_workers or min(len(prefixes), dht.max_workers
                                      or len(prefixes))
     dht_responses = await node.get_many(keys=prefixes,
                                         num_workers=num_workers)
     successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
     for prefix, found in dht_responses.items():
         if found and isinstance(found.value, dict):
             # keep only well-formed entries: integer coordinate within the grid -> [uid, endpoint] pair
             successors[prefix] = {
                 coord: UidEndpoint(*match.value)
                 for coord, match in found.value.items()
                 if isinstance(coord, Coordinate) and 0 <= coord < grid_size
                 and isinstance(getattr(match, 'value', None), list)
                 and len(match.value) == 2
             }
         else:
             successors[prefix] = {}
             if found is None and negative_caching:
                 logger.debug(
                     f"DHT negative caching: storing a 'no prefix' entry for {prefix}"
                 )
                 # fire-and-forget store: cache the absence so future lookups skip this prefix
                 asyncio.create_task(
                     node.store(prefix,
                                subkey=-1,
                                value=None,
                                expiration_time=get_dht_time() +
                                cache_expiration))
     return successors
示例#20
0
 def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
     """
     Track an exponential moving average of seconds-per-sample and the derived samples-per-second rate.

     :param alpha: EMA smoothing factor applied to each new seconds-per-sample measurement
     :param eps: lower bound used to avoid division by zero; also the initial samples_per_second
     """
     self.alpha = alpha
     self.eps = eps
     self.num_updates = 0
     self.ema_seconds_per_sample = 0
     self.samples_per_second = eps
     self.timestamp = get_dht_time()
     self.paused = False
示例#21
0
    def step(self, batch_size: Optional[int] = None, **kwargs):
        """
        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

        :param batch_size: optional override for batch_size_per_step from init
        :raises ValueError: if neither batch_size_per_step (at init) nor batch_size (here) was provided
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        # the first provided batch_size becomes the permanent default
        if self.batch_size_per_step is None:
            if batch_size is None:
                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
            self.batch_size_per_step = batch_size
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        if not self.is_synchronized:
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return

        # warn if the time between steps exceeded metadata_expiration: peers may have acted on stale metadata
        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
            logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                           f"but metadata expired in {self.metadata_expiration} s.")

        self.accumulate_grads_(batch_size)

        with self.lock_local_progress:
            # register local progress and wake the background progress reporter
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
        # re-fetch the latest collaboration state right before averaging
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        with self.performance_ema.pause(), self.lock_collaboration_state:
            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
            current_step, group_info = self.averager.local_step, None

            if self.collaboration_state.num_peers > 1:
                # weight this peer's contribution by its share of the expected per-worker sample count
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
                try:
                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                    if group_info:
                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
                except BaseException as e:
                    # a failed averaging round is not fatal: fall through and apply the local update anyway
                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")

            else:
                logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
                                                 f"{self.collaboration_state.num_peers} peer(s).")

            self.opt.step()
            self.reset_accumulated_grads_()
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.collaboration_state.register_step(current_step + 1)
            self.averager.local_step = current_step + 1
            self.collaboration_state_updated.set()
            self.update_scheduler()

        logger.log(self.status_loglevel, f"Optimizer step: done!")

        return group_info
示例#22
0
    def step(self, batch_size: Optional[int] = None, **kwargs):
        """
        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

        :param batch_size: optional override for batch_size_per_step from init
        :raises ValueError: if neither batch_size_per_step (at init) nor batch_size (here) was provided
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        # bug fix: the previous condition (`batch_size is not None and self.batch_size_per_step is None`)
        # raised exactly when the caller compensated for a missing batch_size_per_step by passing batch_size;
        # an error is only warranted when neither value is available
        if batch_size is None and self.batch_size_per_step is None:
            raise ValueError(
                "Please either set batch_size_per_step parameter at init or provide batch_size in .step"
            )
        batch_size = self.batch_size_per_step if batch_size is None else batch_size

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        # warn if the time between steps exceeded metadata_expiration: peers may have acted on stale metadata
        if self.last_step_time is not None and get_dht_time(
        ) - self.last_step_time > self.metadata_expiration:
            logger.warning(
                f"Training step took {get_dht_time() - self.last_step_time}, "
                f"but metadata expired in {self.metadata_expiration} s.")

        with self.lock_local_progress:
            self.local_samples_accumulated += batch_size
            self.local_steps_accumulated += 1
            # bug fix: report the effective batch_size; self.batch_size_per_step may be None here,
            # which would fail PerformanceEMA.update's `num_processed > 0` assertion
            self.performance_ema.update(num_processed=batch_size)
            self.should_report_progress.set()

        if not self.collaboration_state.ready_for_step:
            return

        logger.log(self.status_loglevel,
                   "Averaging parameters and gradients with peers...")
        # re-fetch the latest collaboration state right before averaging
        self.collaboration_state = self.fetch_collaboration_state()
        self.collaboration_state_updated.set()

        if not self.is_synchronized:
            self.load_state_from_peers()
            return

        with self.performance_ema.pause(), self.lock_collaboration_state:
            if self.collaboration_state.num_peers > 1:
                # weight this peer's contribution by its share of the expected per-worker sample count
                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                weight = self.local_samples_accumulated / mean_samples_per_worker
                output = self.averager.step(weight=weight,
                                            timeout=self.averaging_timeout,
                                            **kwargs)
            else:
                logger.log(
                    self.status_loglevel,
                    f"Skipped averaging: collaboration consists of "
                    f"{self.collaboration_state.num_peers} peer(s).")
                output = None
                self.averager.local_step += 1

            self.opt.step()
            self.opt.zero_grad()
            self.local_samples_accumulated = self.local_steps_accumulated = 0
            self.collaboration_state.register_step()
            self.collaboration_state_updated.set()
            self.update_scheduler()

            logger.log(self.status_loglevel, f"Optimizer step: done!")
            return output
示例#23
0
 def ready_for_step(self):
     """ True once enough samples were accumulated for the target batch, or the next-step ETA has passed. """
     enough_samples = self.samples_accumulated >= self.target_batch_size
     deadline_passed = get_dht_time() >= self.eta_next_step
     return enough_samples or deadline_passed