        # Per-batch body of the training loop; the forward pass that
        # produces `output` runs just before this fragment.
        loss = F.nll_loss(output, target)
        time.sleep(0.5)  # artificial delay so the demo runs slowly enough to observe
        loss.backward()
        print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
            epoch,
            batch_idx * len(data),
            len(train_loader.dataset),
            100.0 * batch_idx / len(train_loader),
            loss.item(),
        ))
        if ftlib.skip_allreduce():
            # Nothing to average with (e.g. this worker is currently alone
            # in the group), so apply the local gradients directly.
            logging.info("skip allreduce")
            optimizer.step()
            continue
        else:
            res = ftlib.wait_gradients_ready(model)
        if res == FTAllReduceStatus.NO_NEED:
            logging.critical(
                "cannot use average_gradient when there is no need")
            exit(2)
        if res == FTAllReduceStatus.SUCCESS:
            logging.info("average succeed")
            optimizer.step()
        if res == FTAllReduceStatus.ABORT:
            # The allreduce failed (e.g. membership changed mid-operation):
            # drop this batch's update and continue with the next batch.
            logging.info("average failed, abort")
            continue
    scheduler.step()
logging.info("terminate!")
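The fragment above assumes a standard PyTorch setup plus an already-constructed ftlib instance. The sketch below shows one way that scaffolding could look; the model, the synthetic dataset, the hyperparameters, and the import paths are illustrative assumptions, not definitions from the original demo.

import logging
import time

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

from ftlib import BasicFTLib  # assumed import path
from ftlib import FTAllReduceStatus  # assumed import path

# Tiny stand-in model that emits log-probabilities, as F.nll_loss expects.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 10),
    torch.nn.LogSoftmax(dim=1),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=1)

# Synthetic MNIST-shaped data so the sketch is self-contained.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.randn(512, 1, 28, 28), torch.randint(0, 10, (512,))
    ),
    batch_size=64,
)

# Constructor arguments taken from the communicator below; other
# consensus settings are omitted here.
ftlib = BasicFTLib(consensus="gossip", commlib="pytorch")

for epoch in range(1, 3):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        # ... per-batch fragment shown above goes here ...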
class CollectiveCommunicator(object):
    def __init__(self, service_name=None):
        if _FTLIB_INSTALLED:
            # Resolving the consensus service may fail while peer pods are
            # still starting up, so retry every 5 seconds until the
            # connection timeout is exceeded.
            connection_try_num = 0
            while True:
                try:
                    peer_list = list(self._get_peer_set(service_name))
                except Exception:
                    if (connection_try_num * 5
                            > _FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS):
                        logger.error(
                            "Cannot connect to FTLib consensus service in %s "
                            "seconds",
                            str(_FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS),
                        )
                        self._ftlib = None
                        return
                    # Sleep for 5 seconds and try again
                    logger.info("Cannot connect to FTLib consensus service, "
                                "trying again.")
                    connection_try_num += 1
                    time.sleep(5)
                else:
                    break
            self._ftlib = BasicFTLib(
                consensus="gossip",
                commlib="pytorch",
                consensus_init_kwargs={
                    "known_addr_list": peer_list,
                    "custom_bind_addr": socket.gethostbyname(
                        socket.gethostname()),
                },
            )
            while peer_list and not self._ftlib.consensus_joined():
                logger.warning("Retry building consensus...")
                self._ftlib.manual_join(
                    known_addr_list=list(self._get_peer_set(service_name)))
        else:
            logger.warning(
                "FTLib is not installed. The CollectiveCommunicator "
                "may not work as expected")
            self._ftlib = None

    def tf_allreduce(self, grads, op="MEAN"):
        if grads is None:
            logger.error("Grads is required for tf_allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, grads
        # Convert tf.Tensor to numpy
        numpy_data = [g.numpy() for g in grads]
        return self.allreduce(numpy_data, op)

    def allreduce(self, data, op="MEAN"):
        if data is None:
            logger.error("Data is required for allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, data
        if op not in _SUPPORTED_ALLREDUCE_OPS:
            logger.error(
                "%s is not in list of supported allreduce operations: %s"
                % (op, _SUPPORTED_ALLREDUCE_OPS))
            return CollectiveCommunicatorStatus.FAILED, data
        if self._ftlib is not None:
            status, res = self._ftlib.wait_gradients_ready(params=data)
            if ((status == FTCollectiveStatus.SUCCESS
                    and res == CommLibStatus.SUCCESS)
                    or status == FTCollectiveStatus.NO_NEED):
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def tf_broadcast(self, params, src_rank):
        for p in params:
            status, data = self.broadcast(p.numpy(), src_rank)
            if status == CollectiveCommunicatorStatus.SUCCEEDED:
                p.assign(data)
            else:
                return status
        return CollectiveCommunicatorStatus.SUCCEEDED

    def broadcast(self, data, src_rank):
        if self._ftlib is not None:
            status, _ = self._ftlib.broadcast(data, src_rank)
            if status == FTCollectiveStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def barrier(self):
        if self._ftlib is not None:
            status, _ = self._ftlib.barrier()
            if status == FTCollectiveStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED
            else:
                return CollectiveCommunicatorStatus.FAILED
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED

    def is_initialized(self):
        """This will be `False` under three occasions:

        * New workers report joining in
        * Collective-communication operations fail or time out
        * Liveness probe fails for existing workers
        """
        if self._ftlib is not None:
            return self._ftlib.initialized
        else:
            return True

    def _get_peer_set(self, svc_name):
        if svc_name is None:
            return None
        my_ip = socket.gethostbyname(socket.gethostname())
        temp_set = socket.getaddrinfo(svc_name, 0,
                                      proto=socket.IPPROTO_TCP)
        peer_set = {peer[-1][0] for peer in temp_set if peer[-1][0] != my_ip}
        return peer_set
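As a usage sketch under stated assumptions: a worker builds the communicator against a headless service whose DNS name resolves to all worker pods (which is what _get_peer_set relies on), then averages a list of numpy gradient arrays across workers. The service name "worker-svc" and the apply_gradients helper are hypothetical.

import numpy as np

# "worker-svc" is a hypothetical headless-service name used for peer
# discovery; in Kubernetes it resolves to every worker pod IP.
communicator = CollectiveCommunicator(service_name="worker-svc")

grads = [np.ones((4, 4), dtype=np.float32), np.ones(4, dtype=np.float32)]
status, averaged = communicator.allreduce(grads, op="MEAN")
if status == CollectiveCommunicatorStatus.SUCCEEDED:
    apply_gradients(averaged)  # hypothetical helper that applies the result
else:
    # A failed collective (e.g. a peer exited mid-allreduce) is meant to be
    # recoverable: skip this step and let the group re-form before the
    # next call.
    pass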
class CollectiveCommunicator(object):
    def __init__(self, service_name=None):
        if _FTLIB_INSTALLED:
            self._ftlib = BasicFTLib(
                consensus="gossip",
                commlib="pytorch",
                consensus_init_kwargs={
                    "known_addr_list": list(
                        self._get_peer_set(service_name)),
                    "custom_bind_addr": socket.gethostbyname(
                        socket.gethostname()),
                },
            )
            while not self._ftlib.consensus_joined():
                logger.warning("Retry building consensus...")
                self._ftlib.manual_join(
                    known_addr_list=list(self._get_peer_set(service_name)))
        else:
            logger.warning(
                "FTLib is not installed. The CollectiveCommunicator "
                "may not work as expected")
            self._ftlib = None

    def allreduce(self, data, op="MEAN"):
        if data is None:
            logger.error("Data is required for allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, data
        if op not in _SUPPORTED_ALLREDUCE_OPS:
            logger.error(
                "%s is not in list of supported allreduce operations: %s"
                % (op, _SUPPORTED_ALLREDUCE_OPS))
            return CollectiveCommunicatorStatus.FAILED, data
        if self._ftlib is not None:
            res = self._ftlib.wait_gradients_ready(data)
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def broadcast(self, data, src_rank):
        if self._ftlib is not None:
            res = self._ftlib.broadcast(data, src_rank)
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def barrier(self):
        if self._ftlib is not None:
            res = self._ftlib.barrier()
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED
            else:
                return CollectiveCommunicatorStatus.FAILED
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED

    def is_initialized(self):
        """This will be `False` under three occasions:

        * New workers report joining in
        * Collective-communication operations fail or time out
        * Liveness probe fails for existing workers
        """
        if self._ftlib is not None:
            return self._ftlib.initialized
        else:
            return True

    def _get_peer_set(self, svc_name):
        if svc_name is None:
            return None
        my_ip = socket.gethostbyname(socket.gethostname())
        temp_set = socket.getaddrinfo(svc_name, 0, proto=socket.IPPROTO_TCP)
        peer_set = {peer[-1][0] for peer in temp_set if peer[-1][0] != my_ip}
        return peer_set
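The is_initialized docstring implies a worker-side recovery pattern: whenever membership changes or a collective fails, initialized flips to False and parameters should be re-synchronized before training continues. A minimal sketch of that pattern, assuming rank 0 as the source of truth; the service name and the model_params, load_params, and train_step helpers are hypothetical.

communicator = CollectiveCommunicator(service_name="worker-svc")  # name assumed

for batch in batches:  # any iterable of training batches
    if not communicator.is_initialized():
        # Membership changed or a previous collective failed: restore a
        # consistent state by re-broadcasting parameters from rank 0.
        status, synced = communicator.broadcast(model_params(), src_rank=0)
        if status == CollectiveCommunicatorStatus.FAILED:
            continue  # group still re-forming; try again on the next batch
        load_params(synced)  # hypothetical inverse of model_params()
    train_step(batch)  # hypothetical local forward/backward plus allreduce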