def __init__(self, model, byteps_opt, num_steps=10**6):
    """Build a ScheduledOptimizer on top of a byteps optimizer.

    The wrapped byteps optimizer is what averages gradients across all
    workers and applies the updates.

    Args:
        model: The training model. ByteScheduler uses the model object to
            register hooks.
        byteps_opt: Optimizer used for averaging gradients and applying
            updates.
        num_steps: The maximum number of training steps. ByteScheduler
            needs it to know when to stop cross-iteration scheduling.
    """
    self._model = model
    self._opt = byteps_opt

    logger = logging.getLogger("ByteScheduler")
    self._logger = logger
    logger.debug("byteps size {}, rank {}".format(size(), rank()))
    self._desc = "rank {}".format(rank())

    # Training-step bookkeeping: current step and the step to stop at.
    self._step = 0
    self._final_step = num_steps

    # One lock per parameter; the locks are used to block the forward
    # propagation of each parameter.
    self._locks = {
        param: threading.Lock()
        for group in self.param_groups
        for param in group['params']
    }

    # Hooks are only registered when more than one worker takes part.
    multiple_workers = size() > 1
    if multiple_workers:
        self._register_forward_hooks()
        self._register_hooks()

    # Background thread that polls whether the tensor push-pull is finished.
    self._event_queue = queue.Queue()
    poller = threading.Thread(target=self._poll)
    self._poller = poller
    poller.start()
def broadcast_parameters(params, root_rank):
    """Broadcast the parameters from root rank to all other processes.

    Typical usage is to broadcast ``model.state_dict()``, or
    ``list(model.named_parameters())`` / ``list(model.parameters())``.
    Only ``dict`` and ``list`` inputs are accepted, so a generator must be
    wrapped in ``list()`` first.

    Arguments:
        params: One of the following:
            - dict mapping names to parameters (broadcast in sorted key order)
            - list of ``(name, parameter)`` tuples
            - list of bare parameters (broadcast without a tensor name)
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.

    Raises:
        ValueError: If ``params`` is neither a ``dict`` nor a ``list``.
    """
    if isinstance(params, dict):
        named_params = sorted(params.items())
    elif isinstance(params, list):
        # Support both named (name, param) tuples and bare parameters;
        # bare parameters get a None name.
        named_params = [
            entry if isinstance(entry, tuple) else (None, entry)
            for entry in params
        ]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run synchronous broadcasts, one parameter at a time.
    for name, p in named_params:
        # Broadcast is implemented as push + pull in BytePS.
        # To make it a real broadcast, we set the non-root tensors all 0.
        if rank() != root_rank:
            p.fill_(0)
        # Averaging is disabled because we are doing a broadcast.
        if name is None:
            # Bare parameters carry no name; concatenating None onto
            # "Parameter." would raise TypeError, so omit the name instead.
            # NOTE(review): assumes byteps_push_pull has a default for
            # `name` -- confirm against its signature.
            handle = byteps_push_pull(p, average=False)
        else:
            handle = byteps_push_pull(p, average=False,
                                      name="Parameter." + name)
        synchronize(handle)