Example #1
import logging
import time

import torch.nn.functional as F

# Reconstructed context: the fragment below uses `epoch`, `batch_idx`, `data`,
# `target`, `train_loader`, `optimizer`, `scheduler` and `ftlib`, so it belongs
# inside a standard PyTorch epoch/batch training loop like this one.
# `FTAllReduceStatus` and the `ftlib` handle come from the FTLib package.
def train(model, train_loader, optimizer, scheduler, ftlib, epochs):
    for epoch in range(1, epochs + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            time.sleep(0.5)
            loss.backward()
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
            ))

            if ftlib.skip_allreduce():
                logging.info("skip allreduce")
                optimizer.step()
                continue
            else:
                res = ftlib.wait_gradients_ready(model)
            if res == FTAllReduceStatus.NO_NEED:
                logging.critical(
                    "cannot use average_gradient when there is no need")
                exit(2)
            if res == FTAllReduceStatus.SUCCESS:
                logging.info("average succeed")
                optimizer.step()
            if res == FTAllReduceStatus.ABORT:
                logging.info("average failed, abort")
                continue
        scheduler.step()

    logging.info("terminate!")
Example #2

class CollectiveCommunicator(object):
    def __init__(self, service_name=None):
        if _FTLIB_INSTALLED:
            connection_try_num = 0
            while True:
                try:
                    peer_list = list(self._get_peer_set(service_name))
                except Exception:
                    if (connection_try_num * 5 >
                            _FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS):
                        logger.error(
                            "Cannot connect to FTLib consensus service in %s "
                            "seconds",
                            str(_FTLIB_CONSENSUS_CONNECTION_TIMEOUT_SECS),
                        )
                        self._ftlib = None
                        return
                    # sleep for 5s and try again
                    logger.info("Cannot connect to FTLib consensus service, "
                                "trying again.")
                    connection_try_num += 1
                    time.sleep(5)
                else:
                    break

            self._ftlib = BasicFTLib(
                consensus="gossip",
                commlib="pytorch",
                consensus_init_kwargs={
                    "known_addr_list": peer_list,
                    "custom_bind_addr":
                    socket.gethostbyname(socket.gethostname()),
                },
            )
            while peer_list and not self._ftlib.consensus_joined():
                logger.warning("Retry building consensus...")
                self._ftlib.manual_join(
                    known_addr_list=list(self._get_peer_set(service_name)))
        else:
            logger.warning(
                "FTLib is not installed. The CollectiveCommunicator "
                "may not work as expected")
            self._ftlib = None

    def tf_allreduce(self, grads, op="MEAN"):
        if grads is None:
            logger.error("Grads is required for tf_allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, grads
        # convert tf.Tensor to numpy
        numpy_data = [g.numpy() for g in grads]
        return self.allreduce(numpy_data, op)

    def allreduce(self, data, op="MEAN"):
        if data is None:
            logger.error("Data is required for allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, data
        if op not in _SUPPORTED_ALLREDUCE_OPS:
            logger.error(
                "%s is not in list of supported allreduce operations: %s" %
                (op, _SUPPORTED_ALLREDUCE_OPS))
            return CollectiveCommunicatorStatus.FAILED, data
        if self._ftlib is not None:
            status, res = self._ftlib.wait_gradients_ready(params=data)
            if (status == FTCollectiveStatus.SUCCESS
                    and res == CommLibStatus.SUCCESS
                    or status == FTCollectiveStatus.NO_NEED):
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def tf_broadcast(self, params, src_rank):
        for p in params:
            data = p.numpy()
            status, data = self.broadcast(data, src_rank)
            if status == CollectiveCommunicatorStatus.SUCCEEDED:
                p.assign(data)
            else:
                return status
        return CollectiveCommunicatorStatus.SUCCEEDED

    def broadcast(self, data, src_rank):
        if self._ftlib is not None:
            status, _ = self._ftlib.broadcast(data, src_rank)
            if status == FTCollectiveStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def barrier(self):
        if self._ftlib is not None:
            status, _ = self._ftlib.barrier()
            if status == FTCollectiveStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED
            else:
                return CollectiveCommunicatorStatus.FAILED
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED

    def is_initialized(self):
        """This will be `False` under three occasions:
           * New workers report joining in
           * Collective-communication operations fail or time out
           * Liveness probe fails for existing workers
        """
        if self._ftlib is not None:
            return self._ftlib.initialized
        else:
            return True

    def _get_peer_set(self, svc_name):
        if svc_name is None:
            return None
        my_ip = socket.gethostbyname(socket.gethostname())
        temp_set = socket.getaddrinfo(svc_name, 0, proto=socket.IPPROTO_TCP)
        peer_set = {peer[-1][0] for peer in temp_set if peer[-1][0] != my_ip}
        return peer_set
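
A rough usage sketch of the class above inside a TensorFlow training step, assuming CollectiveCommunicatorStatus is in scope; the service name, model and batch are hypothetical placeholders.

import tensorflow as tf

# Hypothetical model, optimizer, batch and consensus service name.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.SGD(0.1)
communicator = CollectiveCommunicator(service_name="worker-svc")

features = tf.random.normal((8, 4))
labels = tf.zeros((8,), dtype=tf.int32)

with tf.GradientTape() as tape:
    logits = model(features)
    loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True))
grads = tape.gradient(loss, model.trainable_variables)

# Average gradients across workers; on failure, skip the update and let the
# next iteration re-synchronize after membership settles.
status, averaged = communicator.tf_allreduce(grads, op="MEAN")
if status == CollectiveCommunicatorStatus.SUCCEEDED:
    averaged = [tf.convert_to_tensor(g) for g in averaged]
    optimizer.apply_gradients(zip(averaged, model.trainable_variables))

Note that tf_allreduce returns plain numpy arrays, hence the convert_to_tensor step before apply_gradients.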
Example #3
class CollectiveCommunicator(object):
    def __init__(self, service_name=None):
        if _FTLIB_INSTALLED:
            self._ftlib = BasicFTLib(
                consensus="gossip",
                commlib="pytorch",
                consensus_init_kwargs={
                    "known_addr_list": list(self._get_peer_set(service_name)),
                    "custom_bind_addr":
                    socket.gethostbyname(socket.gethostname()),
                },
            )
            while not self._ftlib.consensus_joined():
                logger.warning("Retry building consensus...")
                self._ftlib.manual_join(
                    known_addr_list=list(self._get_peer_set(service_name)))
        else:
            logger.warning(
                "FTLib is not installed. The CollectiveCommunicator "
                "may not work as expected")
            self._ftlib = None

    def allreduce(self, data, op="MEAN"):
        if data is None:
            logger.error("Data is required for allreduce operation")
            return CollectiveCommunicatorStatus.FAILED, data
        if op not in _SUPPORTED_ALLREDUCE_OPS:
            logger.error(
                "%s is not in list of supported allreduce operations: %s" %
                (op, _SUPPORTED_ALLREDUCE_OPS))
            return CollectiveCommunicatorStatus.FAILED, data
        if self._ftlib is not None:
            res = self._ftlib.wait_gradients_ready(data)
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def broadcast(self, data, src_rank):
        if self._ftlib is not None:
            res = self._ftlib.broadcast(data, src_rank)
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED, data
            else:
                return CollectiveCommunicatorStatus.FAILED, data
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED, data

    def barrier(self):
        if self._ftlib is not None:
            res = self._ftlib.barrier()
            if res == FTAllReduceStatus.SUCCESS:
                return CollectiveCommunicatorStatus.SUCCEEDED
            else:
                return CollectiveCommunicatorStatus.FAILED
        else:
            logger.warning(_FTLIB_UNINSTALLED_DEFAULT_STATUS_MESSAGE)
            return CollectiveCommunicatorStatus.SUCCEEDED

    def is_initialized(self):
        """This will be `False` under three occasions:
           * New workers report joining in
           * Collective-communication operations fail or time out
           * Liveness probe fails for existing workers
        """
        if self._ftlib is not None:
            return self._ftlib.initialized
        else:
            return True

    def _get_peer_set(self, svc_name):
        if svc_name is None:
            return None
        my_ip = socket.gethostbyname(socket.gethostname())
        temp_set = socket.getaddrinfo(svc_name, 0, proto=socket.IPPROTO_TCP)
        peer_set = {peer[-1][0] for peer in temp_set if peer[-1][0] != my_ip}
        return peer_set
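
For this simpler variant, a minimal driver sketch with plain numpy data; the service name and array shapes are placeholders, and error handling is reduced to a status check.

import numpy as np

communicator = CollectiveCommunicator(service_name="worker-svc")

# Average a list of gradient arrays across all workers.
grads = [np.ones((4, 4), dtype=np.float32), np.zeros(4, dtype=np.float32)]
status, grads = communicator.allreduce(grads, op="MEAN")
if status == CollectiveCommunicatorStatus.FAILED:
    logger.warning("allreduce failed; skipping this update")

# Rank 0 pushes its parameters to the other workers, then everyone syncs up.
status, weights = communicator.broadcast(np.zeros(10, dtype=np.float32), src_rank=0)
communicator.barrier()

if not communicator.is_initialized():
    # Membership changed (e.g. a new worker joined); callers typically
    # re-broadcast parameters from rank 0 before continuing training.
    pass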